Skip to content
Snippets Groups Projects
Commit 46f5726d authored by Andy Lindeman's avatar Andy Lindeman
Browse files

Use oidc.Verifier to verify tokens

parent 157c359f
No related branches found
No related tags found
No related merge requests found
......@@ -2,7 +2,6 @@ package server
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
......@@ -15,6 +14,7 @@ import (
"sync"
"time"
oidc "github.com/coreos/go-oidc"
"github.com/gorilla/mux"
jose "gopkg.in/square/go-jose.v2"
......@@ -23,10 +23,6 @@ import (
"github.com/dexidp/dex/storage"
)
var (
errTokenExpired = errors.New("token has expired")
)
// newHealthChecker returns the healthz handler. The handler runs until the
// provided context is canceled.
func (s *Server) newHealthChecker(ctx context.Context) http.Handler {
......@@ -1055,84 +1051,31 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
}
func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
authorization := r.Header.Get("Authorization")
parts := strings.Fields(authorization)
const prefix = "Bearer "
if len(parts) != 2 || !strings.EqualFold(parts[0], "bearer") {
msg := "invalid authorization header"
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="dex", error="%s", error_description="%s"`, errInvalidRequest, msg))
s.tokenErrHelper(w, errInvalidRequest, msg, http.StatusBadRequest)
auth := r.Header.Get("authorization")
if len(auth) < len(prefix) || !strings.EqualFold(prefix, auth[:len(prefix)]) {
w.Header().Set("WWW-Authenticate", "Bearer")
s.tokenErrHelper(w, errAccessDenied, "Invalid bearer token.", http.StatusUnauthorized)
return
}
rawIDToken := auth[len(prefix):]
token := parts[1]
verified, err := s.verify(token)
verifier := oidc.NewVerifier(s.issuerURL.String(), &storageKeySet{s.storage}, &oidc.Config{SkipClientIDCheck: true})
idToken, err := verifier.Verify(r.Context(), rawIDToken)
if err != nil {
if err == errTokenExpired {
s.tokenErrHelper(w, errAccessDenied, err.Error(), http.StatusUnauthorized)
return
}
s.tokenErrHelper(w, errInvalidRequest, err.Error(), http.StatusBadRequest)
s.tokenErrHelper(w, errAccessDenied, err.Error(), http.StatusForbidden)
return
}
w.Header().Set("Content-Type", "application/json")
w.Write(verified)
}
func (s *Server) verify(token string) ([]byte, error) {
keys, err := s.storage.GetKeys()
if err != nil {
return nil, fmt.Errorf("failed to get keys: %v", err)
}
if keys.SigningKey == nil {
return nil, fmt.Errorf("no private keys found")
}
object, err := jose.ParseSigned(token)
if err != nil {
return nil, fmt.Errorf("unable to parse signed message")
}
parts := strings.Split(token, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("compact JWS format must have three parts")
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, err
}
// TODO: check other claims
var tokenInfo struct {
Expiry int64 `json:"exp"`
}
if err := json.Unmarshal(payload, &tokenInfo); err != nil {
return nil, err
}
if tokenInfo.Expiry < s.now().Unix() {
return nil, errTokenExpired
}
var allKeys []*jose.JSONWebKey
allKeys = append(allKeys, keys.SigningKeyPub)
for _, key := range keys.VerificationKeys {
allKeys = append(allKeys, key.PublicKey)
var claims json.RawMessage
if err := idToken.Claims(&claims); err != nil {
s.tokenErrHelper(w, errServerError, err.Error(), http.StatusInternalServerError)
return
}
for _, pubKey := range allKeys {
verified, err := object.Verify(pubKey)
if err == nil {
return verified, nil
}
}
return nil, errors.New("unable to verify jwt")
w.Header().Set("Content-Type", "application/json")
w.Write(claims)
}
func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, refreshToken string, expiry time.Time) {
......
package server
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
......@@ -566,3 +567,41 @@ func validateRedirectURI(client storage.Client, redirectURI string) bool {
host, _, err := net.SplitHostPort(u.Host)
return err == nil && host == "localhost"
}
// storageKeySet implements the oidc.KeySet interface backed by Dex storage
type storageKeySet struct {
storage.Storage
}
func (s *storageKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) {
jws, err := jose.ParseSigned(jwt)
if err != nil {
return nil, err
}
keyID := ""
for _, sig := range jws.Signatures {
keyID = sig.Header.KeyID
break
}
skeys, err := s.Storage.GetKeys()
if err != nil {
return nil, err
}
keys := []*jose.JSONWebKey{skeys.SigningKeyPub}
for _, vk := range skeys.VerificationKeys {
keys = append(keys, vk.PublicKey)
}
for _, key := range keys {
if keyID == "" || key.KeyID == keyID {
if payload, err := jws.Verify(key); err == nil {
return payload, nil
}
}
}
return nil, errors.New("failed to verify id token signature")
}
......@@ -2,6 +2,8 @@ package server
import (
"context"
"crypto/rand"
"crypto/rsa"
"net/http"
"net/http/httptest"
"net/url"
......@@ -11,6 +13,7 @@ import (
jose "gopkg.in/square/go-jose.v2"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/memory"
)
func TestParseAuthorizationRequest(t *testing.T) {
......@@ -259,3 +262,87 @@ func TestValidRedirectURI(t *testing.T) {
}
}
}
func TestStorageKeySet(t *testing.T) {
s := memory.New(logger)
if err := s.UpdateKeys(func(keys storage.Keys) (storage.Keys, error) {
keys.SigningKey = &jose.JSONWebKey{
Key: testKey,
KeyID: "testkey",
Algorithm: "RS256",
Use: "sig",
}
keys.SigningKeyPub = &jose.JSONWebKey{
Key: testKey.Public(),
KeyID: "testkey",
Algorithm: "RS256",
Use: "sig",
}
return keys, nil
}); err != nil {
t.Fatal(err)
}
tests := []struct {
name string
tokenGenerator func() (jwt string, err error)
wantErr bool
}{
{
name: "valid token",
tokenGenerator: func() (string, error) {
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: testKey}, nil)
if err != nil {
return "", err
}
jws, err := signer.Sign([]byte("payload"))
if err != nil {
return "", err
}
return jws.CompactSerialize()
},
wantErr: false,
},
{
name: "token signed by different key",
tokenGenerator: func() (string, error) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return "", err
}
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: key}, nil)
if err != nil {
return "", err
}
jws, err := signer.Sign([]byte("payload"))
if err != nil {
return "", err
}
return jws.CompactSerialize()
},
wantErr: true,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
jwt, err := tc.tokenGenerator()
if err != nil {
t.Fatal(err)
}
keySet := &storageKeySet{s}
_, err = keySet.VerifySignature(context.Background(), jwt)
if (err != nil && !tc.wantErr) || (err == nil && tc.wantErr) {
t.Fatalf("wantErr = %v, but got err = %v", tc.wantErr, err)
}
})
}
}
......@@ -200,6 +200,16 @@ func TestOAuth2CodeFlow(t *testing.T) {
return nil
},
},
{
name: "fetch userinfo",
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
_, err := p.UserInfo(ctx, config.TokenSource(ctx, token))
if err != nil {
return fmt.Errorf("failed to fetch userinfo: %v", err)
}
return nil
},
},
{
name: "verify id token and oauth2 token expiry",
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment