Skip to content
Snippets Groups Projects
main.go 8.87 KiB
Newer Older
  • Learn to ignore specific revisions
  • Eric Chiang's avatar
    Eric Chiang committed
    package main
    
    import (
    	"bytes"
    
    Eric Chiang's avatar
    Eric Chiang committed
    	"crypto/tls"
    	"crypto/x509"
    	"encoding/json"
    	"errors"
    	"fmt"
    	"io/ioutil"
    	"log"
    	"net"
    	"net/http"
    
    	"net/http/httputil"
    
    Eric Chiang's avatar
    Eric Chiang committed
    	"net/url"
    	"os"
    	"strings"
    	"time"
    
    
    	"github.com/coreos/go-oidc"
    
    Eric Chiang's avatar
    Eric Chiang committed
    	"github.com/spf13/cobra"
    	"golang.org/x/oauth2"
    )
    
    
    const exampleAppState = "I wish to wash my irish wristwatch"
    
    
    Eric Chiang's avatar
    Eric Chiang committed
    type app struct {
    	clientID     string
    	clientSecret string
    	redirectURI  string
    
    	verifier *oidc.IDTokenVerifier
    	provider *oidc.Provider
    
    
    	// Does the provider use "offline_access" scope to request a refresh token
    	// or does it use "access_type=offline" (e.g. Google)?
    	offlineAsScope bool
    
    
    	client *http.Client
    
    Eric Chiang's avatar
    Eric Chiang committed
    }
    
    // return an HTTP client which trusts the provided root CAs.
    func httpClientForRootCAs(rootCAs string) (*http.Client, error) {
    	tlsConfig := tls.Config{RootCAs: x509.NewCertPool()}
    	rootCABytes, err := ioutil.ReadFile(rootCAs)
    	if err != nil {
    		return nil, fmt.Errorf("failed to read root-ca: %v", err)
    	}
    	if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCABytes) {
    		return nil, fmt.Errorf("no certs found in root CA file %q", rootCAs)
    	}
    	return &http.Client{
    		Transport: &http.Transport{
    			TLSClientConfig: &tlsConfig,
    			Proxy:           http.ProxyFromEnvironment,
    			Dial: (&net.Dialer{
    				Timeout:   30 * time.Second,
    				KeepAlive: 30 * time.Second,
    			}).Dial,
    			TLSHandshakeTimeout:   10 * time.Second,
    			ExpectContinueTimeout: 1 * time.Second,
    		},
    	}, nil
    }
    
    
    type debugTransport struct {
    	t http.RoundTripper
    }
    
    func (d debugTransport) RoundTrip(req *http.Request) (*http.Response, error) {
    	reqDump, err := httputil.DumpRequest(req, true)
    	if err != nil {
    		return nil, err
    	}
    	log.Printf("%s", reqDump)
    
    	resp, err := d.t.RoundTrip(req)
    	if err != nil {
    		return nil, err
    	}
    
    	respDump, err := httputil.DumpResponse(resp, true)
    	if err != nil {
    		resp.Body.Close()
    		return nil, err
    	}
    	log.Printf("%s", respDump)
    	return resp, nil
    }
    
    
    Eric Chiang's avatar
    Eric Chiang committed
    func cmd() *cobra.Command {
    	var (
    		a         app
    		issuerURL string
    		listen    string
    		tlsCert   string
    		tlsKey    string
    		rootCAs   string
    
    		debug     bool
    
    Eric Chiang's avatar
    Eric Chiang committed
    	)
    	c := cobra.Command{
    		Use:   "example-app",
    		Short: "An example OpenID Connect client",
    		Long:  "",
    		RunE: func(cmd *cobra.Command, args []string) error {
    			if len(args) != 0 {
    				return errors.New("surplus arguments provided")
    			}
    
    			u, err := url.Parse(a.redirectURI)
    			if err != nil {
    				return fmt.Errorf("parse redirect-uri: %v", err)
    			}
    			listenURL, err := url.Parse(listen)
    			if err != nil {
    				return fmt.Errorf("parse listen address: %v", err)
    			}
    
    			if rootCAs != "" {
    				client, err := httpClientForRootCAs(rootCAs)
    				if err != nil {
    					return err
    				}
    
    				a.client = client
    
    			if debug {
    
    				if a.client == nil {
    					a.client = &http.Client{
    
    						Transport: debugTransport{http.DefaultTransport},
    
    					}
    				} else {
    					a.client.Transport = debugTransport{a.client.Transport}
    
    			if a.client == nil {
    				a.client = http.DefaultClient
    			}
    
    
    Eric Chiang's avatar
    Eric Chiang committed
    			// TODO(ericchiang): Retry with backoff
    
    			ctx := oidc.ClientContext(context.Background(), a.client)
    			provider, err := oidc.NewProvider(ctx, issuerURL)
    
    Eric Chiang's avatar
    Eric Chiang committed
    			if err != nil {
    				return fmt.Errorf("Failed to query provider %q: %v", issuerURL, err)
    			}
    
    
    			var s struct {
    				// What scopes does a provider support?
    				//
    				// See: https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
    				ScopesSupported []string `json:"scopes_supported"`
    			}
    			if err := provider.Claims(&s); err != nil {
    				return fmt.Errorf("Failed to parse provider scopes_supported: %v", err)
    			}
    
    			if len(s.ScopesSupported) == 0 {
    				// scopes_supported is a "RECOMMENDED" discovery claim, not a required
    				// one. If missing, assume that the provider follows the spec and has
    				// an "offline_access" scope.
    				a.offlineAsScope = true
    			} else {
    				// See if scopes_supported has the "offline_access" scope.
    				a.offlineAsScope = func() bool {
    					for _, scope := range s.ScopesSupported {
    						if scope == oidc.ScopeOfflineAccess {
    							return true
    						}
    					}
    					return false
    				}()
    			}
    
    
    Eric Chiang's avatar
    Eric Chiang committed
    			a.provider = provider
    
    			a.verifier = provider.Verifier(&oidc.Config{ClientID: a.clientID})
    
    Eric Chiang's avatar
    Eric Chiang committed
    
    			http.HandleFunc("/", a.handleIndex)
    			http.HandleFunc("/login", a.handleLogin)
    			http.HandleFunc(u.Path, a.handleCallback)
    
    			switch listenURL.Scheme {
    			case "http":
    				log.Printf("listening on %s", listen)
    				return http.ListenAndServe(listenURL.Host, nil)
    			case "https":
    				log.Printf("listening on %s", listen)
    				return http.ListenAndServeTLS(listenURL.Host, tlsCert, tlsKey, nil)
    			default:
    				return fmt.Errorf("listen address %q is not using http or https", listen)
    			}
    		},
    	}
    	c.Flags().StringVar(&a.clientID, "client-id", "example-app", "OAuth2 client ID of this application.")
    	c.Flags().StringVar(&a.clientSecret, "client-secret", "ZXhhbXBsZS1hcHAtc2VjcmV0", "OAuth2 client secret of this application.")
    	c.Flags().StringVar(&a.redirectURI, "redirect-uri", "http://127.0.0.1:5555/callback", "Callback URL for OAuth2 responses.")
    
    	c.Flags().StringVar(&issuerURL, "issuer", "http://127.0.0.1:5556/dex", "URL of the OpenID Connect issuer.")
    
    Eric Chiang's avatar
    Eric Chiang committed
    	c.Flags().StringVar(&listen, "listen", "http://127.0.0.1:5555", "HTTP(S) address to listen at.")
    	c.Flags().StringVar(&tlsCert, "tls-cert", "", "X509 cert file to present when serving HTTPS.")
    	c.Flags().StringVar(&tlsKey, "tls-key", "", "Private key for the HTTPS cert.")
    	c.Flags().StringVar(&rootCAs, "issuer-root-ca", "", "Root certificate authorities for the issuer. Defaults to host certs.")
    
    	c.Flags().BoolVar(&debug, "debug", false, "Print all request and responses from the OpenID Connect issuer.")
    
    Eric Chiang's avatar
    Eric Chiang committed
    	return &c
    }
    
    func main() {
    	if err := cmd().Execute(); err != nil {
    		fmt.Fprintf(os.Stderr, "error: %v\n", err)
    		os.Exit(2)
    	}
    }
    
    func (a *app) handleIndex(w http.ResponseWriter, r *http.Request) {
    	renderIndex(w)
    }
    
    func (a *app) oauth2Config(scopes []string) *oauth2.Config {
    	return &oauth2.Config{
    		ClientID:     a.clientID,
    		ClientSecret: a.clientSecret,
    		Endpoint:     a.provider.Endpoint(),
    		Scopes:       scopes,
    		RedirectURL:  a.redirectURI,
    	}
    }
    
    func (a *app) handleLogin(w http.ResponseWriter, r *http.Request) {
    	var scopes []string
    	if extraScopes := r.FormValue("extra_scopes"); extraScopes != "" {
    		scopes = strings.Split(extraScopes, " ")
    	}
    	var clients []string
    	if crossClients := r.FormValue("cross_client"); crossClients != "" {
    		clients = strings.Split(crossClients, " ")
    	}
    	for _, client := range clients {
    		scopes = append(scopes, "audience:server:client_id:"+client)
    	}
    
    
    	authCodeURL := ""
    	scopes = append(scopes, "openid", "profile", "email")
    
    	if r.FormValue("offline_access") != "yes" {
    
    		authCodeURL = a.oauth2Config(scopes).AuthCodeURL(exampleAppState)
    	} else if a.offlineAsScope {
    
    		scopes = append(scopes, "offline_access")
    
    		authCodeURL = a.oauth2Config(scopes).AuthCodeURL(exampleAppState)
    
    		authCodeURL = a.oauth2Config(scopes).AuthCodeURL(exampleAppState, oauth2.AccessTypeOffline)
    
    	http.Redirect(w, r, authCodeURL, http.StatusSeeOther)
    
    Eric Chiang's avatar
    Eric Chiang committed
    }
    
    func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) {
    	var (
    		err   error
    		token *oauth2.Token
    	)
    
    
    	ctx := oidc.ClientContext(r.Context(), a.client)
    
    Eric Chiang's avatar
    Eric Chiang committed
    	oauth2Config := a.oauth2Config(nil)
    
    	switch r.Method {
    	case "GET":
    		// Authorization redirect callback from OAuth2 auth flow.
    		if errMsg := r.FormValue("error"); errMsg != "" {
    			http.Error(w, errMsg+": "+r.FormValue("error_description"), http.StatusBadRequest)
    			return
    		}
    		code := r.FormValue("code")
    		if code == "" {
    			http.Error(w, fmt.Sprintf("no code in request: %q", r.Form), http.StatusBadRequest)
    			return
    		}
    		if state := r.FormValue("state"); state != exampleAppState {
    			http.Error(w, fmt.Sprintf("expected state %q got %q", exampleAppState, state), http.StatusBadRequest)
    			return
    		}
    
    		token, err = oauth2Config.Exchange(ctx, code)
    
    	case "POST":
    		// Form request from frontend to refresh a token.
    		refresh := r.FormValue("refresh_token")
    		if refresh == "" {
    			http.Error(w, fmt.Sprintf("no refresh_token in request: %q", r.Form), http.StatusBadRequest)
    			return
    		}
    
    Eric Chiang's avatar
    Eric Chiang committed
    		t := &oauth2.Token{
    			RefreshToken: refresh,
    			Expiry:       time.Now().Add(-time.Hour),
    		}
    
    		token, err = oauth2Config.TokenSource(ctx, t).Token()
    
    Eric Chiang's avatar
    Eric Chiang committed
    	default:
    
    		http.Error(w, fmt.Sprintf("method not implemented: %s", r.Method), http.StatusBadRequest)
    
    Eric Chiang's avatar
    Eric Chiang committed
    		return
    	}
    
    	if err != nil {
    		http.Error(w, fmt.Sprintf("failed to get token: %v", err), http.StatusInternalServerError)
    		return
    	}
    
    	rawIDToken, ok := token.Extra("id_token").(string)
    	if !ok {
    		http.Error(w, "no id_token in token response", http.StatusInternalServerError)
    		return
    	}
    
    
    	idToken, err := a.verifier.Verify(r.Context(), rawIDToken)
    
    Eric Chiang's avatar
    Eric Chiang committed
    	if err != nil {
    		http.Error(w, fmt.Sprintf("Failed to verify ID token: %v", err), http.StatusInternalServerError)
    		return
    	}
    	var claims json.RawMessage
    	idToken.Claims(&claims)
    
    	buff := new(bytes.Buffer)
    	json.Indent(buff, []byte(claims), "", "  ")
    
    	renderToken(w, a.redirectURI, rawIDToken, token.RefreshToken, buff.Bytes())
    }