From ba7b8ca336123017e43a2ab3310fd4a82122ef9d Mon Sep 17 00:00:00 2001
From: Achille Roussel <achille.roussel@gmail.com>
Date: Thu, 21 Dec 2023 16:54:24 -0800
Subject: [PATCH] iter: reduce memory footprint of iter.Pull functions
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The implementation of iter.Pull and iter.Pull2 functions is based on
closures and sharing local state, which results in one heap allocation
for each captured variable.

The number of heap allocations can be reduced by grouping the state
shared between closures in a struct, allowing the compiler to allocate
all local variables in a single heap region instead of creating
individual heap objects for each variable.

This approach can sometimes have downsides when it couples unrelated
objects in a single memory region, preventing the garbage collector from
reclaiming unused memory. While technically only a subset of the local
state is shared between the next and stop functions, it seems unlikely
that retaining the rest of the state until stop is reclaimed would be
problematic in practice, since the two closures would often have very
similar lifetimes.

The change also reduces the total memory footprint due to alignment
rules, the two booleans can be packed in memory and sometimes can even
exist within the padding space of the v value. There is also less
metadata needed for the garbage collector to track each individual heap
allocation.

goos: darwin
goarch: arm64
pkg: iter
cpu: Apple M2 Pro
         │ /tmp/bench.old │           /tmp/bench.new            │
         │     sec/op     │   sec/op     vs base                │
Pull-12       218.6n ± 7%   146.1n ± 0%  -33.19% (p=0.000 n=10)
Pull2-12      239.8n ± 5%   155.0n ± 5%  -35.36% (p=0.000 n=10)
geomean       229.0n        150.5n       -34.28%

         │ /tmp/bench.old │           /tmp/bench.new           │
         │      B/op      │    B/op     vs base                │
Pull-12        288.0 ± 0%   176.0 ± 0%  -38.89% (p=0.000 n=10)
Pull2-12       312.0 ± 0%   176.0 ± 0%  -43.59% (p=0.000 n=10)
geomean        299.8        176.0       -41.29%

         │ /tmp/bench.old │           /tmp/bench.new           │
         │   allocs/op    │ allocs/op   vs base                │
Pull-12       11.000 ± 0%   5.000 ± 0%  -54.55% (p=0.000 n=10)
Pull2-12      12.000 ± 0%   5.000 ± 0%  -58.33% (p=0.000 n=10)
geomean        11.49        5.000       -56.48%

Change-Id: Iccbe233e8ae11066087ffa4781b66489d0d410a7
Reviewed-on: https://go-review.googlesource.com/c/go/+/552375
Reviewed-by: Sean Liao <sean@liao.dev>
Auto-Submit: Sean Liao <sean@liao.dev>
Reviewed-by: qiu laidongfeng2 <2645477756@qq.com>
Reviewed-by: David Chase <drchase@google.com>
Reviewed-by: Cherry Mui <cherryyz@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
---
 src/iter/iter.go      | 152 +++++++++++++++++++++---------------------
 src/iter/pull_test.go |  16 +++++
 2 files changed, 92 insertions(+), 76 deletions(-)

diff --git a/src/iter/iter.go b/src/iter/iter.go
index e765378ef25..4d408e5e77c 100644
--- a/src/iter/iter.go
+++ b/src/iter/iter.go
@@ -257,91 +257,91 @@ func coroswitch(*coro)
 // If the iterator panics during a call to next (or stop),
 // then next (or stop) itself panics with the same value.
 func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) {
-	var (
+	var pull struct {
 		v          V
 		ok         bool
 		done       bool
 		yieldNext  bool
+		seqDone    bool // to detect Goexit
 		racer      int
 		panicValue any
-		seqDone    bool // to detect Goexit
-	)
+	}
 	c := newcoro(func(c *coro) {
-		race.Acquire(unsafe.Pointer(&racer))
-		if done {
-			race.Release(unsafe.Pointer(&racer))
+		race.Acquire(unsafe.Pointer(&pull.racer))
+		if pull.done {
+			race.Release(unsafe.Pointer(&pull.racer))
 			return
 		}
 		yield := func(v1 V) bool {
-			if done {
+			if pull.done {
 				return false
 			}
-			if !yieldNext {
+			if !pull.yieldNext {
 				panic("iter.Pull: yield called again before next")
 			}
-			yieldNext = false
-			v, ok = v1, true
-			race.Release(unsafe.Pointer(&racer))
+			pull.yieldNext = false
+			pull.v, pull.ok = v1, true
+			race.Release(unsafe.Pointer(&pull.racer))
 			coroswitch(c)
-			race.Acquire(unsafe.Pointer(&racer))
-			return !done
+			race.Acquire(unsafe.Pointer(&pull.racer))
+			return !pull.done
 		}
 		// Recover and propagate panics from seq.
 		defer func() {
 			if p := recover(); p != nil {
-				panicValue = p
-			} else if !seqDone {
-				panicValue = goexitPanicValue
+				pull.panicValue = p
+			} else if !pull.seqDone {
+				pull.panicValue = goexitPanicValue
 			}
-			done = true // Invalidate iterator
-			race.Release(unsafe.Pointer(&racer))
+			pull.done = true // Invalidate iterator
+			race.Release(unsafe.Pointer(&pull.racer))
 		}()
 		seq(yield)
 		var v0 V
-		v, ok = v0, false
-		seqDone = true
+		pull.v, pull.ok = v0, false
+		pull.seqDone = true
 	})
 	next = func() (v1 V, ok1 bool) {
-		race.Write(unsafe.Pointer(&racer)) // detect races
+		race.Write(unsafe.Pointer(&pull.racer)) // detect races
 
-		if done {
+		if pull.done {
 			return
 		}
-		if yieldNext {
+		if pull.yieldNext {
 			panic("iter.Pull: next called again before yield")
 		}
-		yieldNext = true
-		race.Release(unsafe.Pointer(&racer))
+		pull.yieldNext = true
+		race.Release(unsafe.Pointer(&pull.racer))
 		coroswitch(c)
-		race.Acquire(unsafe.Pointer(&racer))
+		race.Acquire(unsafe.Pointer(&pull.racer))
 
 		// Propagate panics and goexits from seq.
-		if panicValue != nil {
-			if panicValue == goexitPanicValue {
+		if pull.panicValue != nil {
+			if pull.panicValue == goexitPanicValue {
 				// Propagate runtime.Goexit from seq.
 				runtime.Goexit()
 			} else {
-				panic(panicValue)
+				panic(pull.panicValue)
 			}
 		}
-		return v, ok
+		return pull.v, pull.ok
 	}
 	stop = func() {
-		race.Write(unsafe.Pointer(&racer)) // detect races
+		race.Write(unsafe.Pointer(&pull.racer)) // detect races
 
-		if !done {
-			done = true
-			race.Release(unsafe.Pointer(&racer))
+		if !pull.done {
+			pull.done = true
+			race.Release(unsafe.Pointer(&pull.racer))
 			coroswitch(c)
-			race.Acquire(unsafe.Pointer(&racer))
+			race.Acquire(unsafe.Pointer(&pull.racer))
 
 			// Propagate panics and goexits from seq.
-			if panicValue != nil {
-				if panicValue == goexitPanicValue {
+			if pull.panicValue != nil {
+				if pull.panicValue == goexitPanicValue {
 					// Propagate runtime.Goexit from seq.
 					runtime.Goexit()
 				} else {
-					panic(panicValue)
+					panic(pull.panicValue)
 				}
 			}
 		}
@@ -372,93 +372,93 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) {
 // If the iterator panics during a call to next (or stop),
 // then next (or stop) itself panics with the same value.
 func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) {
-	var (
+	var pull struct {
 		k          K
 		v          V
 		ok         bool
 		done       bool
 		yieldNext  bool
+		seqDone    bool
 		racer      int
 		panicValue any
-		seqDone    bool
-	)
+	}
 	c := newcoro(func(c *coro) {
-		race.Acquire(unsafe.Pointer(&racer))
-		if done {
-			race.Release(unsafe.Pointer(&racer))
+		race.Acquire(unsafe.Pointer(&pull.racer))
+		if pull.done {
+			race.Release(unsafe.Pointer(&pull.racer))
 			return
 		}
 		yield := func(k1 K, v1 V) bool {
-			if done {
+			if pull.done {
 				return false
 			}
-			if !yieldNext {
+			if !pull.yieldNext {
 				panic("iter.Pull2: yield called again before next")
 			}
-			yieldNext = false
-			k, v, ok = k1, v1, true
-			race.Release(unsafe.Pointer(&racer))
+			pull.yieldNext = false
+			pull.k, pull.v, pull.ok = k1, v1, true
+			race.Release(unsafe.Pointer(&pull.racer))
 			coroswitch(c)
-			race.Acquire(unsafe.Pointer(&racer))
-			return !done
+			race.Acquire(unsafe.Pointer(&pull.racer))
+			return !pull.done
 		}
 		// Recover and propagate panics from seq.
 		defer func() {
 			if p := recover(); p != nil {
-				panicValue = p
-			} else if !seqDone {
-				panicValue = goexitPanicValue
+				pull.panicValue = p
+			} else if !pull.seqDone {
+				pull.panicValue = goexitPanicValue
 			}
-			done = true // Invalidate iterator.
-			race.Release(unsafe.Pointer(&racer))
+			pull.done = true // Invalidate iterator.
+			race.Release(unsafe.Pointer(&pull.racer))
 		}()
 		seq(yield)
 		var k0 K
 		var v0 V
-		k, v, ok = k0, v0, false
-		seqDone = true
+		pull.k, pull.v, pull.ok = k0, v0, false
+		pull.seqDone = true
 	})
 	next = func() (k1 K, v1 V, ok1 bool) {
-		race.Write(unsafe.Pointer(&racer)) // detect races
+		race.Write(unsafe.Pointer(&pull.racer)) // detect races
 
-		if done {
+		if pull.done {
 			return
 		}
-		if yieldNext {
+		if pull.yieldNext {
 			panic("iter.Pull2: next called again before yield")
 		}
-		yieldNext = true
-		race.Release(unsafe.Pointer(&racer))
+		pull.yieldNext = true
+		race.Release(unsafe.Pointer(&pull.racer))
 		coroswitch(c)
-		race.Acquire(unsafe.Pointer(&racer))
+		race.Acquire(unsafe.Pointer(&pull.racer))
 
 		// Propagate panics and goexits from seq.
-		if panicValue != nil {
-			if panicValue == goexitPanicValue {
+		if pull.panicValue != nil {
+			if pull.panicValue == goexitPanicValue {
 				// Propagate runtime.Goexit from seq.
 				runtime.Goexit()
 			} else {
-				panic(panicValue)
+				panic(pull.panicValue)
 			}
 		}
-		return k, v, ok
+		return pull.k, pull.v, pull.ok
 	}
 	stop = func() {
-		race.Write(unsafe.Pointer(&racer)) // detect races
+		race.Write(unsafe.Pointer(&pull.racer)) // detect races
 
-		if !done {
-			done = true
-			race.Release(unsafe.Pointer(&racer))
+		if !pull.done {
+			pull.done = true
+			race.Release(unsafe.Pointer(&pull.racer))
 			coroswitch(c)
-			race.Acquire(unsafe.Pointer(&racer))
+			race.Acquire(unsafe.Pointer(&pull.racer))
 
 			// Propagate panics and goexits from seq.
-			if panicValue != nil {
-				if panicValue == goexitPanicValue {
+			if pull.panicValue != nil {
+				if pull.panicValue == goexitPanicValue {
 					// Propagate runtime.Goexit from seq.
 					runtime.Goexit()
 				} else {
-					panic(panicValue)
+					panic(pull.panicValue)
 				}
 			}
 		}
diff --git a/src/iter/pull_test.go b/src/iter/pull_test.go
index e9e3bdadca0..c66e20897bd 100644
--- a/src/iter/pull_test.go
+++ b/src/iter/pull_test.go
@@ -491,3 +491,19 @@ func TestPull2ImmediateStop(t *testing.T) {
 		t.Fatal("next returned true after iterator was stopped")
 	}
 }
+
+func BenchmarkPull(b *testing.B) {
+	seq := count(1)
+	for range b.N {
+		_, stop := Pull(seq)
+		stop()
+	}
+}
+
+func BenchmarkPull2(b *testing.B) {
+	seq := squares(1)
+	for range b.N {
+		_, stop := Pull2(seq)
+		stop()
+	}
+}
-- 
GitLab