Skip to content
Snippets Groups Projects
aead.go 12.4 KiB
Newer Older
  • Learn to ignore specific revisions
  • // Copyright 2024 The Go Authors. All rights reserved.
    // Use of this source code is governed by a BSD-style
    // license that can be found in the LICENSE file.
    
    package cryptotest
    
    import (
    	"bytes"
    	"crypto/cipher"
    	"fmt"
    	"testing"
    )
    
    var lengths = []int{0, 156, 8192, 8193, 8208}
    
    // MakeAEAD returns a cipher.AEAD instance.
    //
    // Multiple calls to MakeAEAD must return equivalent instances, so for example
    // the key must be fixed.
    type MakeAEAD func() (cipher.AEAD, error)
    
    // TestAEAD performs a set of tests on cipher.AEAD implementations, checking
    // the documented requirements of NonceSize, Overhead, Seal and Open.
    func TestAEAD(t *testing.T, mAEAD MakeAEAD) {
    	aead, err := mAEAD()
    	if err != nil {
    		t.Fatal(err)
    	}
    
    	t.Run("Roundtrip", func(t *testing.T) {
    
    		// Test all combinations of plaintext and additional data lengths.
    		for _, ptLen := range lengths {
    			for _, adLen := range lengths {
    				t.Run(fmt.Sprintf("Plaintext-Length=%d,AddData-Length=%d", ptLen, adLen), func(t *testing.T) {
    					rng := newRandReader(t)
    
    					nonce := make([]byte, aead.NonceSize())
    					rng.Read(nonce)
    
    					before, addData := make([]byte, adLen), make([]byte, ptLen)
    					rng.Read(before)
    					rng.Read(addData)
    
    					ciphertext := sealMsg(t, aead, nil, nonce, before, addData)
    					after := openWithoutError(t, aead, nil, nonce, ciphertext, addData)
    
    					if !bytes.Equal(after, before) {
    						t.Errorf("plaintext is different after a seal/open cycle; got %s, want %s", truncateHex(after), truncateHex(before))
    					}
    				})
    			}
    		}
    	})
    
    	t.Run("InputNotModified", func(t *testing.T) {
    
    		// Test all combinations of plaintext and additional data lengths.
    		for _, ptLen := range lengths {
    			for _, adLen := range lengths {
    				t.Run(fmt.Sprintf("Plaintext-Length=%d,AddData-Length=%d", ptLen, adLen), func(t *testing.T) {
    					t.Run("Seal", func(t *testing.T) {
    						rng := newRandReader(t)
    
    						nonce := make([]byte, aead.NonceSize())
    						rng.Read(nonce)
    
    						src, before := make([]byte, ptLen), make([]byte, ptLen)
    						rng.Read(src)
    						copy(before, src)
    
    						addData := make([]byte, adLen)
    						rng.Read(addData)
    
    						sealMsg(t, aead, nil, nonce, src, addData)
    						if !bytes.Equal(src, before) {
    							t.Errorf("Seal modified src; got %s, want %s", truncateHex(src), truncateHex(before))
    						}
    					})
    
    					t.Run("Open", func(t *testing.T) {
    						rng := newRandReader(t)
    
    						nonce := make([]byte, aead.NonceSize())
    						rng.Read(nonce)
    
    						plaintext, addData := make([]byte, ptLen), make([]byte, adLen)
    						rng.Read(plaintext)
    						rng.Read(addData)
    
    						// Record the ciphertext that shouldn't be modified as the input of
    						// Open.
    						ciphertext := sealMsg(t, aead, nil, nonce, plaintext, addData)
    						before := make([]byte, len(ciphertext))
    						copy(before, ciphertext)
    
    						openWithoutError(t, aead, nil, nonce, ciphertext, addData)
    						if !bytes.Equal(ciphertext, before) {
    							t.Errorf("Open modified src; got %s, want %s", truncateHex(ciphertext), truncateHex(before))
    						}
    					})
    				})
    			}
    		}
    	})
    
    	t.Run("BufferOverlap", func(t *testing.T) {
    
    		// Test all combinations of plaintext and additional data lengths.
    		for _, ptLen := range lengths {
    			if ptLen <= 1 { // We need enough room for an overlap to occur.
    				continue
    			}
    			for _, adLen := range lengths {
    				t.Run(fmt.Sprintf("Plaintext-Length=%d,AddData-Length=%d", ptLen, adLen), func(t *testing.T) {
    					t.Run("Seal", func(t *testing.T) {
    						rng := newRandReader(t)
    
    						nonce := make([]byte, aead.NonceSize())
    						rng.Read(nonce)
    
    						// Make a buffer that can hold a plaintext and ciphertext as we
    						// overlap their slices to check for panic on inexact overlaps.
    						ctLen := ptLen + aead.Overhead()
    						buff := make([]byte, ptLen+ctLen)
    						rng.Read(buff)
    
    						addData := make([]byte, adLen)
    						rng.Read(addData)
    
    						// Make plaintext and dst slices point to same array with inexact overlap.
    						plaintext := buff[:ptLen]
    						dst := buff[1:1] // Shift dst to not start at start of plaintext.
    						mustPanic(t, "invalid buffer overlap", func() { sealMsg(t, aead, dst, nonce, plaintext, addData) })
    
    						// Only overlap on one byte
    						plaintext = buff[:ptLen]
    						dst = buff[ptLen-1 : ptLen-1]
    						mustPanic(t, "invalid buffer overlap", func() { sealMsg(t, aead, dst, nonce, plaintext, addData) })
    					})
    
    					t.Run("Open", func(t *testing.T) {
    						rng := newRandReader(t)
    
    						nonce := make([]byte, aead.NonceSize())
    						rng.Read(nonce)
    
    						// Create a valid ciphertext to test Open with.
    						plaintext := make([]byte, ptLen)
    						rng.Read(plaintext)
    						addData := make([]byte, adLen)
    						rng.Read(addData)
    						validCT := sealMsg(t, aead, nil, nonce, plaintext, addData)
    
    						// Make a buffer that can hold a plaintext and ciphertext as we
    						// overlap their slices to check for panic on inexact overlaps.
    						buff := make([]byte, ptLen+len(validCT))
    
    						// Make ciphertext and dst slices point to same array with inexact overlap.
    						ciphertext := buff[:len(validCT)]
    						copy(ciphertext, validCT)
    						dst := buff[1:1] // Shift dst to not start at start of ciphertext.
    						mustPanic(t, "invalid buffer overlap", func() { aead.Open(dst, nonce, ciphertext, addData) })
    
    						// Only overlap on one byte.
    						ciphertext = buff[:len(validCT)]
    						copy(ciphertext, validCT)
    						// Make sure it is the actual ciphertext being overlapped and not
    						// the hash digest which might be extracted/truncated in some
    						// implementations: Go one byte past the hash digest/tag and into
    						// the ciphertext.
    						beforeTag := len(validCT) - aead.Overhead()
    						dst = buff[beforeTag-1 : beforeTag-1]
    						mustPanic(t, "invalid buffer overlap", func() { aead.Open(dst, nonce, ciphertext, addData) })
    					})
    				})
    			}
    		}
    	})
    
    	t.Run("AppendDst", func(t *testing.T) {
    
    		// Test all combinations of plaintext and additional data lengths.
    		for _, ptLen := range lengths {
    			for _, adLen := range lengths {
    				t.Run(fmt.Sprintf("Plaintext-Length=%d,AddData-Length=%d", ptLen, adLen), func(t *testing.T) {
    
    					t.Run("Seal", func(t *testing.T) {
    						rng := newRandReader(t)
    
    						nonce := make([]byte, aead.NonceSize())
    						rng.Read(nonce)
    
    						shortBuff := []byte("a")
    						longBuff := make([]byte, 512)
    						rng.Read(longBuff)
    						prefixes := [][]byte{shortBuff, longBuff}
    
    						// Check each prefix gets appended to by Seal with altering them.
    						for _, prefix := range prefixes {
    							plaintext, addData := make([]byte, ptLen), make([]byte, adLen)
    							rng.Read(plaintext)
    							rng.Read(addData)
    							out := sealMsg(t, aead, prefix, nonce, plaintext, addData)
    
    							// Check that Seal didn't alter the prefix
    							if !bytes.Equal(out[0:len(prefix)], prefix) {
    								t.Errorf("Seal alters dst instead of appending; got %s, want %s", truncateHex(out[0:len(prefix)]), truncateHex(prefix))
    							}
    
    							ciphertext := out[len(prefix):]
    							// Check that the appended ciphertext wasn't affected by the prefix
    							if expectedCT := sealMsg(t, aead, nil, nonce, plaintext, addData); !bytes.Equal(ciphertext, expectedCT) {
    								t.Errorf("Seal behavior affected by pre-existing data in dst; got %s, want %s", truncateHex(ciphertext), truncateHex(expectedCT))
    							}
    						}
    					})
    
    					t.Run("Open", func(t *testing.T) {
    						rng := newRandReader(t)
    
    						nonce := make([]byte, aead.NonceSize())
    						rng.Read(nonce)
    
    						shortBuff := []byte("a")
    						longBuff := make([]byte, 512)
    						rng.Read(longBuff)
    						prefixes := [][]byte{shortBuff, longBuff}
    
    						// Check each prefix gets appended to by Open with altering them.
    						for _, prefix := range prefixes {
    							before, addData := make([]byte, adLen), make([]byte, ptLen)
    							rng.Read(before)
    							rng.Read(addData)
    							ciphertext := sealMsg(t, aead, nil, nonce, before, addData)
    
    							out := openWithoutError(t, aead, prefix, nonce, ciphertext, addData)
    
    							// Check that Open didn't alter the prefix
    							if !bytes.Equal(out[0:len(prefix)], prefix) {
    								t.Errorf("Open alters dst instead of appending; got %s, want %s", truncateHex(out[0:len(prefix)]), truncateHex(prefix))
    							}
    
    							after := out[len(prefix):]
    							// Check that the appended plaintext wasn't affected by the prefix
    							if !bytes.Equal(after, before) {
    								t.Errorf("Open behavior affected by pre-existing data in dst; got %s, want %s", truncateHex(after), truncateHex(before))
    							}
    						}
    					})
    				})
    			}
    		}
    	})
    
    	t.Run("WrongNonce", func(t *testing.T) {
    
    		// Test all combinations of plaintext and additional data lengths.
    		for _, ptLen := range lengths {
    			for _, adLen := range lengths {
    				t.Run(fmt.Sprintf("Plaintext-Length=%d,AddData-Length=%d", ptLen, adLen), func(t *testing.T) {
    					rng := newRandReader(t)
    
    					nonce := make([]byte, aead.NonceSize())
    					rng.Read(nonce)
    
    					plaintext, addData := make([]byte, ptLen), make([]byte, adLen)
    					rng.Read(plaintext)
    					rng.Read(addData)
    
    					ciphertext := sealMsg(t, aead, nil, nonce, plaintext, addData)
    
    					// Perturb the nonce and check for an error when Opening
    					alterNonce := make([]byte, aead.NonceSize())
    					copy(alterNonce, nonce)
    					alterNonce[len(alterNonce)-1] += 1
    					_, err := aead.Open(nil, alterNonce, ciphertext, addData)
    
    					if err == nil {
    						t.Errorf("Open did not error when given different nonce than Sealed with")
    					}
    				})
    			}
    		}
    	})
    
    	t.Run("WrongAddData", func(t *testing.T) {
    
    		// Test all combinations of plaintext and additional data lengths.
    		for _, ptLen := range lengths {
    			for _, adLen := range lengths {
    				if adLen == 0 {
    					continue
    				}
    
    				t.Run(fmt.Sprintf("Plaintext-Length=%d,AddData-Length=%d", ptLen, adLen), func(t *testing.T) {
    					rng := newRandReader(t)
    
    					nonce := make([]byte, aead.NonceSize())
    					rng.Read(nonce)
    
    					plaintext, addData := make([]byte, ptLen), make([]byte, adLen)
    					rng.Read(plaintext)
    					rng.Read(addData)
    
    					ciphertext := sealMsg(t, aead, nil, nonce, plaintext, addData)
    
    					// Perturb the Additional Data and check for an error when Opening
    					alterAD := make([]byte, adLen)
    					copy(alterAD, addData)
    					alterAD[len(alterAD)-1] += 1
    					_, err := aead.Open(nil, nonce, ciphertext, alterAD)
    
    					if err == nil {
    						t.Errorf("Open did not error when given different Additional Data than Sealed with")
    					}
    				})
    			}
    		}
    	})
    
    	t.Run("WrongCiphertext", func(t *testing.T) {
    
    		// Test all combinations of plaintext and additional data lengths.
    		for _, ptLen := range lengths {
    			for _, adLen := range lengths {
    
    				t.Run(fmt.Sprintf("Plaintext-Length=%d,AddData-Length=%d", ptLen, adLen), func(t *testing.T) {
    					rng := newRandReader(t)
    
    					nonce := make([]byte, aead.NonceSize())
    					rng.Read(nonce)
    
    					plaintext, addData := make([]byte, ptLen), make([]byte, adLen)
    					rng.Read(plaintext)
    					rng.Read(addData)
    
    					ciphertext := sealMsg(t, aead, nil, nonce, plaintext, addData)
    
    					// Perturb the ciphertext and check for an error when Opening
    					alterCT := make([]byte, len(ciphertext))
    					copy(alterCT, ciphertext)
    					alterCT[len(alterCT)-1] += 1
    					_, err := aead.Open(nil, nonce, alterCT, addData)
    
    					if err == nil {
    						t.Errorf("Open did not error when given different ciphertext than was produced by Seal")
    					}
    				})
    			}
    		}
    	})
    }
    
    // Helper function to Seal a plaintext with additional data. Checks that
    // ciphertext isn't bigger than the plaintext length plus Overhead()
    func sealMsg(t *testing.T, aead cipher.AEAD, ciphertext, nonce, plaintext, addData []byte) []byte {
    	t.Helper()
    
    	initialLen := len(ciphertext)
    
    	ciphertext = aead.Seal(ciphertext, nonce, plaintext, addData)
    
    	lenCT := len(ciphertext) - initialLen
    
    	// Appended ciphertext shouldn't ever be longer than the length of the
    	// plaintext plus Overhead
    	if lenCT > len(plaintext)+aead.Overhead() {
    		t.Errorf("length of ciphertext from Seal exceeds length of plaintext by more than Overhead(); got %d, want <=%d", lenCT, len(plaintext)+aead.Overhead())
    	}
    
    	return ciphertext
    }
    
    // Helper function to Open and authenticate ciphertext. Checks that Open
    // doesn't error (assuming ciphertext was well-formed with corresponding nonce
    // and additional data).
    func openWithoutError(t *testing.T, aead cipher.AEAD, plaintext, nonce, ciphertext, addData []byte) []byte {
    	t.Helper()
    
    	plaintext, err := aead.Open(plaintext, nonce, ciphertext, addData)
    	if err != nil {
    		t.Fatalf("Open returned error on properly formed ciphertext; got \"%s\", want \"nil\"", err)
    	}
    
    	return plaintext
    }