From fa8ff1a46deb6c816304441ec6740ec112e19012 Mon Sep 17 00:00:00 2001
From: Roland Shoemaker <bracewell@google.com>
Date: Fri, 3 May 2024 09:21:39 -0400
Subject: [PATCH] [release-branch.go1.23] encoding/gob: cover missed cases when
 checking ignore depth

This change makes sure that we are properly checking the ignored field
recursion depth in decIgnoreOpFor consistently. This prevents stack
exhaustion when attempting to decode a message that contains an
extremely deeply nested struct which is ignored.

Thanks to Md Sakib Anwar of The Ohio State University (anwar.40@osu.edu)
for reporting this issue.

Updates #69139
Fixes #69145
Fixes CVE-2024-34156

Change-Id: Iacce06be95a5892b3064f1c40fcba2e2567862d6
Reviewed-on: https://go-internal-review.googlesource.com/c/go/+/1440
Reviewed-by: Russ Cox <rsc@google.com>
Reviewed-by: Damien Neil <dneil@google.com>
(cherry picked from commit 9f2ea73c5f2a7056b7da5d579a485a7216f4b20a)
Reviewed-on: https://go-internal-review.googlesource.com/c/go/+/1581
Commit-Queue: Roland Shoemaker <bracewell@google.com>
Reviewed-by: Tatiana Bradley <tatianabradley@google.com>
Reviewed-on: https://go-review.googlesource.com/c/go/+/611176
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Auto-Submit: Dmitri Shuralyov <dmitshur@google.com>
Reviewed-by: Michael Pratt <mpratt@google.com>
TryBot-Bypass: Dmitri Shuralyov <dmitshur@google.com>
---
 src/encoding/gob/decode.go         | 19 +++++++++++--------
 src/encoding/gob/decoder.go        |  2 ++
 src/encoding/gob/gobencdec_test.go | 14 ++++++++++++++
 3 files changed, 27 insertions(+), 8 deletions(-)

diff --git a/src/encoding/gob/decode.go b/src/encoding/gob/decode.go
index d178b2b2fb6..26b5f6d62b6 100644
--- a/src/encoding/gob/decode.go
+++ b/src/encoding/gob/decode.go
@@ -911,8 +911,11 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg
 var maxIgnoreNestingDepth = 10000
 
 // decIgnoreOpFor returns the decoding op for a field that has no destination.
-func (dec *Decoder) decIgnoreOpFor(wireId typeId, inProgress map[typeId]*decOp, depth int) *decOp {
-	if depth > maxIgnoreNestingDepth {
+func (dec *Decoder) decIgnoreOpFor(wireId typeId, inProgress map[typeId]*decOp) *decOp {
+	// Track how deep we've recursed trying to skip nested ignored fields.
+	dec.ignoreDepth++
+	defer func() { dec.ignoreDepth-- }()
+	if dec.ignoreDepth > maxIgnoreNestingDepth {
 		error_(errors.New("invalid nesting depth"))
 	}
 	// If this type is already in progress, it's a recursive type (e.g. map[string]*T).
@@ -938,7 +941,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId, inProgress map[typeId]*decOp,
 			errorf("bad data: undefined type %s", wireId.string())
 		case wire.ArrayT != nil:
 			elemId := wire.ArrayT.Elem
-			elemOp := dec.decIgnoreOpFor(elemId, inProgress, depth+1)
+			elemOp := dec.decIgnoreOpFor(elemId, inProgress)
 			op = func(i *decInstr, state *decoderState, value reflect.Value) {
 				state.dec.ignoreArray(state, *elemOp, wire.ArrayT.Len)
 			}
@@ -946,15 +949,15 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId, inProgress map[typeId]*decOp,
 		case wire.MapT != nil:
 			keyId := dec.wireType[wireId].MapT.Key
 			elemId := dec.wireType[wireId].MapT.Elem
-			keyOp := dec.decIgnoreOpFor(keyId, inProgress, depth+1)
-			elemOp := dec.decIgnoreOpFor(elemId, inProgress, depth+1)
+			keyOp := dec.decIgnoreOpFor(keyId, inProgress)
+			elemOp := dec.decIgnoreOpFor(elemId, inProgress)
 			op = func(i *decInstr, state *decoderState, value reflect.Value) {
 				state.dec.ignoreMap(state, *keyOp, *elemOp)
 			}
 
 		case wire.SliceT != nil:
 			elemId := wire.SliceT.Elem
-			elemOp := dec.decIgnoreOpFor(elemId, inProgress, depth+1)
+			elemOp := dec.decIgnoreOpFor(elemId, inProgress)
 			op = func(i *decInstr, state *decoderState, value reflect.Value) {
 				state.dec.ignoreSlice(state, *elemOp)
 			}
@@ -1115,7 +1118,7 @@ func (dec *Decoder) compileSingle(remoteId typeId, ut *userTypeInfo) (engine *de
 func (dec *Decoder) compileIgnoreSingle(remoteId typeId) *decEngine {
 	engine := new(decEngine)
 	engine.instr = make([]decInstr, 1) // one item
-	op := dec.decIgnoreOpFor(remoteId, make(map[typeId]*decOp), 0)
+	op := dec.decIgnoreOpFor(remoteId, make(map[typeId]*decOp))
 	ovfl := overflow(dec.typeString(remoteId))
 	engine.instr[0] = decInstr{*op, 0, nil, ovfl}
 	engine.numInstr = 1
@@ -1160,7 +1163,7 @@ func (dec *Decoder) compileDec(remoteId typeId, ut *userTypeInfo) (engine *decEn
 		localField, present := srt.FieldByName(wireField.Name)
 		// TODO(r): anonymous names
 		if !present || !isExported(wireField.Name) {
-			op := dec.decIgnoreOpFor(wireField.Id, make(map[typeId]*decOp), 0)
+			op := dec.decIgnoreOpFor(wireField.Id, make(map[typeId]*decOp))
 			engine.instr[fieldnum] = decInstr{*op, fieldnum, nil, ovfl}
 			continue
 		}
diff --git a/src/encoding/gob/decoder.go b/src/encoding/gob/decoder.go
index c4b60880130..eae307838e2 100644
--- a/src/encoding/gob/decoder.go
+++ b/src/encoding/gob/decoder.go
@@ -35,6 +35,8 @@ type Decoder struct {
 	freeList     *decoderState                           // list of free decoderStates; avoids reallocation
 	countBuf     []byte                                  // used for decoding integers while parsing messages
 	err          error
+	// ignoreDepth tracks the depth of recursively parsed ignored fields
+	ignoreDepth int
 }
 
 // NewDecoder returns a new decoder that reads from the [io.Reader].
diff --git a/src/encoding/gob/gobencdec_test.go b/src/encoding/gob/gobencdec_test.go
index ae806fc39a2..d30e622aa2c 100644
--- a/src/encoding/gob/gobencdec_test.go
+++ b/src/encoding/gob/gobencdec_test.go
@@ -806,6 +806,8 @@ func TestIgnoreDepthLimit(t *testing.T) {
 	defer func() { maxIgnoreNestingDepth = oldNestingDepth }()
 	b := new(bytes.Buffer)
 	enc := NewEncoder(b)
+
+	// Nested slice
 	typ := reflect.TypeFor[int]()
 	nested := reflect.ArrayOf(1, typ)
 	for i := 0; i < 100; i++ {
@@ -819,4 +821,16 @@ func TestIgnoreDepthLimit(t *testing.T) {
 	if err := dec.Decode(&output); err == nil || err.Error() != expectedErr {
 		t.Errorf("Decode didn't fail with depth limit of 100: want %q, got %q", expectedErr, err)
 	}
+
+	// Nested struct
+	nested = reflect.StructOf([]reflect.StructField{{Name: "F", Type: typ}})
+	for i := 0; i < 100; i++ {
+		nested = reflect.StructOf([]reflect.StructField{{Name: "F", Type: nested}})
+	}
+	badStruct = reflect.New(reflect.StructOf([]reflect.StructField{{Name: "F", Type: nested}}))
+	enc.Encode(badStruct.Interface())
+	dec = NewDecoder(b)
+	if err := dec.Decode(&output); err == nil || err.Error() != expectedErr {
+		t.Errorf("Decode didn't fail with depth limit of 100: want %q, got %q", expectedErr, err)
+	}
 }
-- 
GitLab