From f5a89dff67ae00bfc70fbfccc1b1cc044e565b50 Mon Sep 17 00:00:00 2001
From: Filippo Valsorda <filippo@golang.org>
Date: Mon, 6 Jan 2025 18:52:35 +0100
Subject: [PATCH] crypto: fix fips140=only detection of SHA-3

Both fips140only and the service indicator checks in
crypto/internal/fips140/... expect to type assert to
crypto/internal/fips140/{sha256,sha512,sha3}.Digest.

However, crypto/sha3 returns a wrapper concrete type around sha3.Digest.

Add a new fips140hash.Unwrap function to turn the wrapper into the
underlying sha3.Digest, and use it consistently before calling into
fips140only or the FIPS 140-3 module.

In crypto/rsa, also made the fips140only checks apply consistently after
the Go+BoringCrypto shims, so we can instantiate the hash, and avoid
having to wrap the New function. Note that fips140=only is incompatible
with Go+BoringCrypto.

Fixes #70879

Change-Id: I6a6a4656ec55c3e13f6cbfadb9cf89c0f9183bdc
Reviewed-on: https://go-review.googlesource.com/c/go/+/640855
Auto-Submit: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Reviewed-by: Russ Cox <rsc@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
---
 src/crypto/ecdsa/ecdsa.go               |   6 +-
 src/crypto/hkdf/hkdf.go                 |  22 ++--
 src/crypto/hmac/hmac.go                 |   2 +
 src/crypto/internal/fips140hash/hash.go |  34 ++++++
 src/crypto/pbkdf2/pbkdf2.go             |   6 +-
 src/crypto/rsa/fips.go                  | 133 +++++++++++++-----------
 src/crypto/sha3/sha3.go                 |   6 ++
 src/go/build/deps_test.go               |   3 +-
 8 files changed, 139 insertions(+), 73 deletions(-)
 create mode 100644 src/crypto/internal/fips140hash/hash.go

diff --git a/src/crypto/ecdsa/ecdsa.go b/src/crypto/ecdsa/ecdsa.go
index d9ebe56ef00..cb308b41e9d 100644
--- a/src/crypto/ecdsa/ecdsa.go
+++ b/src/crypto/ecdsa/ecdsa.go
@@ -23,6 +23,7 @@ import (
 	"crypto/internal/boring"
 	"crypto/internal/boring/bbig"
 	"crypto/internal/fips140/ecdsa"
+	"crypto/internal/fips140hash"
 	"crypto/internal/fips140only"
 	"crypto/internal/randutil"
 	"crypto/sha512"
@@ -281,10 +282,11 @@ func signFIPSDeterministic[P ecdsa.Point[P]](c *ecdsa.Curve[P], hashFunc crypto.
 	if err != nil {
 		return nil, err
 	}
-	if fips140only.Enabled && !fips140only.ApprovedHash(hashFunc.New()) {
+	h := fips140hash.UnwrapNew(hashFunc.New)
+	if fips140only.Enabled && !fips140only.ApprovedHash(h()) {
 		return nil, errors.New("crypto/ecdsa: use of hash functions other than SHA-2 or SHA-3 is not allowed in FIPS 140-only mode")
 	}
-	sig, err := ecdsa.SignDeterministic(c, hashFunc.New, k, hash)
+	sig, err := ecdsa.SignDeterministic(c, h, k, hash)
 	if err != nil {
 		return nil, err
 	}
diff --git a/src/crypto/hkdf/hkdf.go b/src/crypto/hkdf/hkdf.go
index 7cfbe2c60de..6b02522866d 100644
--- a/src/crypto/hkdf/hkdf.go
+++ b/src/crypto/hkdf/hkdf.go
@@ -12,6 +12,7 @@ package hkdf
 
 import (
 	"crypto/internal/fips140/hkdf"
+	"crypto/internal/fips140hash"
 	"crypto/internal/fips140only"
 	"errors"
 	"hash"
@@ -24,10 +25,11 @@ import (
 // Expand invocations and different context values. Most common scenarios,
 // including the generation of multiple keys, should use [Key] instead.
 func Extract[H hash.Hash](h func() H, secret, salt []byte) ([]byte, error) {
-	if err := checkFIPS140Only(h, secret); err != nil {
+	fh := fips140hash.UnwrapNew(h)
+	if err := checkFIPS140Only(fh, secret); err != nil {
 		return nil, err
 	}
-	return hkdf.Extract(h, secret, salt), nil
+	return hkdf.Extract(fh, secret, salt), nil
 }
 
 // Expand derives a key from the given hash, key, and optional context info,
@@ -38,35 +40,37 @@ func Extract[H hash.Hash](h func() H, secret, salt []byte) ([]byte, error) {
 // random or pseudorandom cryptographically strong key. See RFC 5869, Section
 // 3.3. Most common scenarios will want to use [Key] instead.
 func Expand[H hash.Hash](h func() H, pseudorandomKey []byte, info string, keyLength int) ([]byte, error) {
-	if err := checkFIPS140Only(h, pseudorandomKey); err != nil {
+	fh := fips140hash.UnwrapNew(h)
+	if err := checkFIPS140Only(fh, pseudorandomKey); err != nil {
 		return nil, err
 	}
 
-	limit := h().Size() * 255
+	limit := fh().Size() * 255
 	if keyLength > limit {
 		return nil, errors.New("hkdf: requested key length too large")
 	}
 
-	return hkdf.Expand(h, pseudorandomKey, info, keyLength), nil
+	return hkdf.Expand(fh, pseudorandomKey, info, keyLength), nil
 }
 
 // Key derives a key from the given hash, secret, salt and context info,
 // returning a []byte of length keyLength that can be used as cryptographic key.
 // Salt and info can be nil.
 func Key[Hash hash.Hash](h func() Hash, secret, salt []byte, info string, keyLength int) ([]byte, error) {
-	if err := checkFIPS140Only(h, secret); err != nil {
+	fh := fips140hash.UnwrapNew(h)
+	if err := checkFIPS140Only(fh, secret); err != nil {
 		return nil, err
 	}
 
-	limit := h().Size() * 255
+	limit := fh().Size() * 255
 	if keyLength > limit {
 		return nil, errors.New("hkdf: requested key length too large")
 	}
 
-	return hkdf.Key(h, secret, salt, info, keyLength), nil
+	return hkdf.Key(fh, secret, salt, info, keyLength), nil
 }
 
-func checkFIPS140Only[H hash.Hash](h func() H, key []byte) error {
+func checkFIPS140Only[Hash hash.Hash](h func() Hash, key []byte) error {
 	if !fips140only.Enabled {
 		return nil
 	}
diff --git a/src/crypto/hmac/hmac.go b/src/crypto/hmac/hmac.go
index 72f5a4abea9..554c8c9b789 100644
--- a/src/crypto/hmac/hmac.go
+++ b/src/crypto/hmac/hmac.go
@@ -24,6 +24,7 @@ package hmac
 import (
 	"crypto/internal/boring"
 	"crypto/internal/fips140/hmac"
+	"crypto/internal/fips140hash"
 	"crypto/internal/fips140only"
 	"crypto/subtle"
 	"hash"
@@ -43,6 +44,7 @@ func New(h func() hash.Hash, key []byte) hash.Hash {
 		}
 		// BoringCrypto did not recognize h, so fall through to standard Go code.
 	}
+	h = fips140hash.UnwrapNew(h)
 	if fips140only.Enabled {
 		if len(key) < 112/8 {
 			panic("crypto/hmac: use of keys shorter than 112 bits is not allowed in FIPS 140-only mode")
diff --git a/src/crypto/internal/fips140hash/hash.go b/src/crypto/internal/fips140hash/hash.go
new file mode 100644
index 00000000000..6d67ee8b342
--- /dev/null
+++ b/src/crypto/internal/fips140hash/hash.go
@@ -0,0 +1,34 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fips140hash
+
+import (
+	fsha3 "crypto/internal/fips140/sha3"
+	"crypto/sha3"
+	"hash"
+	_ "unsafe"
+)
+
+//go:linkname sha3Unwrap
+func sha3Unwrap(*sha3.SHA3) *fsha3.Digest
+
+// Unwrap returns h, or a crypto/internal/fips140 inner implementation of h.
+//
+// The return value can be type asserted to one of
+// [crypto/internal/fips140/sha256.Digest],
+// [crypto/internal/fips140/sha512.Digest], or
+// [crypto/internal/fips140/sha3.Digest] if it is a FIPS 140-3 approved hash.
+func Unwrap(h hash.Hash) hash.Hash {
+	if sha3, ok := h.(*sha3.SHA3); ok {
+		return sha3Unwrap(sha3)
+	}
+	return h
+}
+
+// UnwrapNew returns a function that calls newHash and applies [Unwrap] to the
+// return value.
+func UnwrapNew[Hash hash.Hash](newHash func() Hash) func() hash.Hash {
+	return func() hash.Hash { return Unwrap(newHash()) }
+}
diff --git a/src/crypto/pbkdf2/pbkdf2.go b/src/crypto/pbkdf2/pbkdf2.go
index d40daab5e5b..271d2b03312 100644
--- a/src/crypto/pbkdf2/pbkdf2.go
+++ b/src/crypto/pbkdf2/pbkdf2.go
@@ -12,6 +12,7 @@ package pbkdf2
 
 import (
 	"crypto/internal/fips140/pbkdf2"
+	"crypto/internal/fips140hash"
 	"crypto/internal/fips140only"
 	"errors"
 	"hash"
@@ -34,6 +35,7 @@ import (
 // Using a higher iteration count will increase the cost of an exhaustive
 // search but will also make derivation proportionally slower.
 func Key[Hash hash.Hash](h func() Hash, password string, salt []byte, iter, keyLength int) ([]byte, error) {
+	fh := fips140hash.UnwrapNew(h)
 	if fips140only.Enabled {
 		if keyLength < 112/8 {
 			return nil, errors.New("crypto/pbkdf2: use of keys shorter than 112 bits is not allowed in FIPS 140-only mode")
@@ -41,9 +43,9 @@ func Key[Hash hash.Hash](h func() Hash, password string, salt []byte, iter, keyL
 		if len(salt) < 128/8 {
 			return nil, errors.New("crypto/pbkdf2: use of salts shorter than 128 bits is not allowed in FIPS 140-only mode")
 		}
-		if !fips140only.ApprovedHash(h()) {
+		if !fips140only.ApprovedHash(fh()) {
 			return nil, errors.New("crypto/pbkdf2: use of hash functions other than SHA-2 or SHA-3 is not allowed in FIPS 140-only mode")
 		}
 	}
-	return pbkdf2.Key(h, password, salt, iter, keyLength)
+	return pbkdf2.Key(fh, password, salt, iter, keyLength)
 }
diff --git a/src/crypto/rsa/fips.go b/src/crypto/rsa/fips.go
index 347775df160..8373c125ae3 100644
--- a/src/crypto/rsa/fips.go
+++ b/src/crypto/rsa/fips.go
@@ -8,6 +8,7 @@ import (
 	"crypto"
 	"crypto/internal/boring"
 	"crypto/internal/fips140/rsa"
+	"crypto/internal/fips140hash"
 	"crypto/internal/fips140only"
 	"errors"
 	"hash"
@@ -64,21 +65,11 @@ func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, digest []byte,
 	if err := checkPublicKeySize(&priv.PublicKey); err != nil {
 		return nil, err
 	}
-	if err := checkFIPS140OnlyPrivateKey(priv); err != nil {
-		return nil, err
-	}
 
 	if opts != nil && opts.Hash != 0 {
 		hash = opts.Hash
 	}
 
-	if fips140only.Enabled && !fips140only.ApprovedHash(hash.New()) {
-		return nil, errors.New("crypto/rsa: use of hash functions other than SHA-2 or SHA-3 is not allowed in FIPS 140-only mode")
-	}
-	if fips140only.Enabled && !fips140only.ApprovedRandomReader(rand) {
-		return nil, errors.New("crypto/rsa: only crypto/rand.Reader is allowed in FIPS 140-only mode")
-	}
-
 	if boring.Enabled && rand == boring.RandReader {
 		bkey, err := boringPrivateKey(priv)
 		if err != nil {
@@ -88,14 +79,25 @@ func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, digest []byte,
 	}
 	boring.UnreachableExceptTests()
 
+	h := fips140hash.Unwrap(hash.New())
+
+	if err := checkFIPS140OnlyPrivateKey(priv); err != nil {
+		return nil, err
+	}
+	if fips140only.Enabled && !fips140only.ApprovedHash(h) {
+		return nil, errors.New("crypto/rsa: use of hash functions other than SHA-2 or SHA-3 is not allowed in FIPS 140-only mode")
+	}
+	if fips140only.Enabled && !fips140only.ApprovedRandomReader(rand) {
+		return nil, errors.New("crypto/rsa: only crypto/rand.Reader is allowed in FIPS 140-only mode")
+	}
+
 	k, err := fipsPrivateKey(priv)
 	if err != nil {
 		return nil, err
 	}
-	h := hash.New()
 
 	saltLength := opts.saltLength()
-	if fips140only.Enabled && saltLength > hash.Size() {
+	if fips140only.Enabled && saltLength > h.Size() {
 		return nil, errors.New("crypto/rsa: use of PSS salt longer than the hash is not allowed in FIPS 140-only mode")
 	}
 	switch saltLength {
@@ -105,7 +107,7 @@ func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, digest []byte,
 			return nil, fipsError(err)
 		}
 	case PSSSaltLengthEqualsHash:
-		saltLength = hash.Size()
+		saltLength = h.Size()
 	default:
 		// If we get here saltLength is either > 0 or < -1, in the
 		// latter case we fail out.
@@ -130,12 +132,6 @@ func VerifyPSS(pub *PublicKey, hash crypto.Hash, digest []byte, sig []byte, opts
 	if err := checkPublicKeySize(pub); err != nil {
 		return err
 	}
-	if err := checkFIPS140OnlyPublicKey(pub); err != nil {
-		return err
-	}
-	if fips140only.Enabled && !fips140only.ApprovedHash(hash.New()) {
-		return errors.New("crypto/rsa: use of hash functions other than SHA-2 or SHA-3 is not allowed in FIPS 140-only mode")
-	}
 
 	if boring.Enabled {
 		bkey, err := boringPublicKey(pub)
@@ -148,22 +144,31 @@ func VerifyPSS(pub *PublicKey, hash crypto.Hash, digest []byte, sig []byte, opts
 		return nil
 	}
 
+	h := fips140hash.Unwrap(hash.New())
+
+	if err := checkFIPS140OnlyPublicKey(pub); err != nil {
+		return err
+	}
+	if fips140only.Enabled && !fips140only.ApprovedHash(h) {
+		return errors.New("crypto/rsa: use of hash functions other than SHA-2 or SHA-3 is not allowed in FIPS 140-only mode")
+	}
+
 	k, err := fipsPublicKey(pub)
 	if err != nil {
 		return err
 	}
 
 	saltLength := opts.saltLength()
-	if fips140only.Enabled && saltLength > hash.Size() {
+	if fips140only.Enabled && saltLength > h.Size() {
 		return errors.New("crypto/rsa: use of PSS salt longer than the hash is not allowed in FIPS 140-only mode")
 	}
 	switch saltLength {
 	case PSSSaltLengthAuto:
-		return fipsError(rsa.VerifyPSS(k, hash.New(), digest, sig))
+		return fipsError(rsa.VerifyPSS(k, h, digest, sig))
 	case PSSSaltLengthEqualsHash:
-		return fipsError(rsa.VerifyPSSWithSaltLength(k, hash.New(), digest, sig, hash.Size()))
+		return fipsError(rsa.VerifyPSSWithSaltLength(k, h, digest, sig, h.Size()))
 	default:
-		return fipsError(rsa.VerifyPSSWithSaltLength(k, hash.New(), digest, sig, saltLength))
+		return fipsError(rsa.VerifyPSSWithSaltLength(k, h, digest, sig, saltLength))
 	}
 }
 
@@ -189,15 +194,6 @@ func EncryptOAEP(hash hash.Hash, random io.Reader, pub *PublicKey, msg []byte, l
 	if err := checkPublicKeySize(pub); err != nil {
 		return nil, err
 	}
-	if err := checkFIPS140OnlyPublicKey(pub); err != nil {
-		return nil, err
-	}
-	if fips140only.Enabled && !fips140only.ApprovedHash(hash) {
-		return nil, errors.New("crypto/rsa: use of hash functions other than SHA-2 or SHA-3 is not allowed in FIPS 140-only mode")
-	}
-	if fips140only.Enabled && !fips140only.ApprovedRandomReader(random) {
-		return nil, errors.New("crypto/rsa: only crypto/rand.Reader is allowed in FIPS 140-only mode")
-	}
 
 	defer hash.Reset()
 
@@ -215,6 +211,18 @@ func EncryptOAEP(hash hash.Hash, random io.Reader, pub *PublicKey, msg []byte, l
 	}
 	boring.UnreachableExceptTests()
 
+	hash = fips140hash.Unwrap(hash)
+
+	if err := checkFIPS140OnlyPublicKey(pub); err != nil {
+		return nil, err
+	}
+	if fips140only.Enabled && !fips140only.ApprovedHash(hash) {
+		return nil, errors.New("crypto/rsa: use of hash functions other than SHA-2 or SHA-3 is not allowed in FIPS 140-only mode")
+	}
+	if fips140only.Enabled && !fips140only.ApprovedRandomReader(random) {
+		return nil, errors.New("crypto/rsa: only crypto/rand.Reader is allowed in FIPS 140-only mode")
+	}
+
 	k, err := fipsPublicKey(pub)
 	if err != nil {
 		return nil, err
@@ -241,14 +249,6 @@ func decryptOAEP(hash, mgfHash hash.Hash, priv *PrivateKey, ciphertext []byte, l
 	if err := checkPublicKeySize(&priv.PublicKey); err != nil {
 		return nil, err
 	}
-	if err := checkFIPS140OnlyPrivateKey(priv); err != nil {
-		return nil, err
-	}
-	if fips140only.Enabled {
-		if !fips140only.ApprovedHash(hash) || !fips140only.ApprovedHash(mgfHash) {
-			return nil, errors.New("crypto/rsa: use of hash functions other than SHA-2 or SHA-3 is not allowed in FIPS 140-only mode")
-		}
-	}
 
 	if boring.Enabled {
 		k := priv.Size()
@@ -267,6 +267,18 @@ func decryptOAEP(hash, mgfHash hash.Hash, priv *PrivateKey, ciphertext []byte, l
 		return out, nil
 	}
 
+	hash = fips140hash.Unwrap(hash)
+	mgfHash = fips140hash.Unwrap(mgfHash)
+
+	if err := checkFIPS140OnlyPrivateKey(priv); err != nil {
+		return nil, err
+	}
+	if fips140only.Enabled {
+		if !fips140only.ApprovedHash(hash) || !fips140only.ApprovedHash(mgfHash) {
+			return nil, errors.New("crypto/rsa: use of hash functions other than SHA-2 or SHA-3 is not allowed in FIPS 140-only mode")
+		}
+	}
+
 	k, err := fipsPrivateKey(priv)
 	if err != nil {
 		return nil, err
@@ -299,12 +311,6 @@ func SignPKCS1v15(random io.Reader, priv *PrivateKey, hash crypto.Hash, hashed [
 	if err := checkPublicKeySize(&priv.PublicKey); err != nil {
 		return nil, err
 	}
-	if err := checkFIPS140OnlyPrivateKey(priv); err != nil {
-		return nil, err
-	}
-	if fips140only.Enabled && !fips140only.ApprovedHash(hash.New()) {
-		return nil, errors.New("crypto/rsa: use of hash functions other than SHA-2 or SHA-3 is not allowed in FIPS 140-only mode")
-	}
 
 	if boring.Enabled {
 		bkey, err := boringPrivateKey(priv)
@@ -314,6 +320,13 @@ func SignPKCS1v15(random io.Reader, priv *PrivateKey, hash crypto.Hash, hashed [
 		return boring.SignRSAPKCS1v15(bkey, hash, hashed)
 	}
 
+	if err := checkFIPS140OnlyPrivateKey(priv); err != nil {
+		return nil, err
+	}
+	if fips140only.Enabled && !fips140only.ApprovedHash(fips140hash.Unwrap(hash.New())) {
+		return nil, errors.New("crypto/rsa: use of hash functions other than SHA-2 or SHA-3 is not allowed in FIPS 140-only mode")
+	}
+
 	k, err := fipsPrivateKey(priv)
 	if err != nil {
 		return nil, err
@@ -330,15 +343,17 @@ func SignPKCS1v15(random io.Reader, priv *PrivateKey, hash crypto.Hash, hashed [
 // The inputs are not considered confidential, and may leak through timing side
 // channels, or if an attacker has control of part of the inputs.
 func VerifyPKCS1v15(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte) error {
-	if err := checkPublicKeySize(pub); err != nil {
-		return err
+	var hashName string
+	if hash != crypto.Hash(0) {
+		if len(hashed) != hash.Size() {
+			return errors.New("crypto/rsa: input must be hashed message")
+		}
+		hashName = hash.String()
 	}
-	if err := checkFIPS140OnlyPublicKey(pub); err != nil {
+
+	if err := checkPublicKeySize(pub); err != nil {
 		return err
 	}
-	if fips140only.Enabled && !fips140only.ApprovedHash(hash.New()) {
-		return errors.New("crypto/rsa: use of hash functions other than SHA-2 or SHA-3 is not allowed in FIPS 140-only mode")
-	}
 
 	if boring.Enabled {
 		bkey, err := boringPublicKey(pub)
@@ -351,17 +366,17 @@ func VerifyPKCS1v15(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte)
 		return nil
 	}
 
+	if err := checkFIPS140OnlyPublicKey(pub); err != nil {
+		return err
+	}
+	if fips140only.Enabled && !fips140only.ApprovedHash(fips140hash.Unwrap(hash.New())) {
+		return errors.New("crypto/rsa: use of hash functions other than SHA-2 or SHA-3 is not allowed in FIPS 140-only mode")
+	}
+
 	k, err := fipsPublicKey(pub)
 	if err != nil {
 		return err
 	}
-	var hashName string
-	if hash != crypto.Hash(0) {
-		if len(hashed) != hash.Size() {
-			return errors.New("crypto/rsa: input must be hashed message")
-		}
-		hashName = hash.String()
-	}
 	return fipsError(rsa.VerifyPKCS1v15(k, hashName, hashed, sig))
 }
 
diff --git a/src/crypto/sha3/sha3.go b/src/crypto/sha3/sha3.go
index 0f4d7ed4370..a6c5ae55f1f 100644
--- a/src/crypto/sha3/sha3.go
+++ b/src/crypto/sha3/sha3.go
@@ -10,6 +10,7 @@ import (
 	"crypto"
 	"crypto/internal/fips140/sha3"
 	"hash"
+	_ "unsafe"
 )
 
 func init() {
@@ -100,6 +101,11 @@ type SHA3 struct {
 	s sha3.Digest
 }
 
+//go:linkname fips140hash_sha3Unwrap crypto/internal/fips140hash.sha3Unwrap
+func fips140hash_sha3Unwrap(sha3 *SHA3) *sha3.Digest {
+	return &sha3.s
+}
+
 // New224 creates a new SHA3-224 hash.
 func New224() *SHA3 {
 	return &SHA3{*sha3.New224()}
diff --git a/src/go/build/deps_test.go b/src/go/build/deps_test.go
index a62a5173b9c..e3e01077c18 100644
--- a/src/go/build/deps_test.go
+++ b/src/go/build/deps_test.go
@@ -510,6 +510,8 @@ var depsRules = `
 	< crypto/internal/fips140only
 	< crypto
 	< crypto/subtle
+	< crypto/sha3
+	< crypto/internal/fips140hash
 	< crypto/cipher
 	< crypto/internal/boring
 	< crypto/boring
@@ -520,7 +522,6 @@ var depsRules = `
 	  crypto/sha1,
 	  crypto/sha256,
 	  crypto/sha512,
-	  crypto/sha3,
 	  crypto/hmac,
 	  crypto/hkdf,
 	  crypto/pbkdf2,
-- 
GitLab