package jwtutil import ( "net/http" "strings" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwt" "github.com/pkg/errors" ) type TokenFinderFunc func(r *http.Request) (string, error) type FindTokenOptions struct { Finders []TokenFinderFunc } type FindTokenOptionFunc func(*FindTokenOptions) type GetKeySetFunc func() (jwk.Set, error) func WithFinders(finders ...TokenFinderFunc) FindTokenOptionFunc { return func(opts *FindTokenOptions) { opts.Finders = finders } } func NewFindTokenOptions(funcs ...FindTokenOptionFunc) *FindTokenOptions { opts := &FindTokenOptions{ Finders: []TokenFinderFunc{ FindTokenFromAuthorizationHeader, }, } for _, fn := range funcs { fn(opts) } return opts } func FindTokenFromAuthorizationHeader(r *http.Request) (string, error) { authorization := r.Header.Get("Authorization") // Retrieve token from Authorization header rawToken := strings.TrimPrefix(authorization, "Bearer ") return rawToken, nil } func FindTokenFromQueryString(name string) TokenFinderFunc { return func(r *http.Request) (string, error) { return r.URL.Query().Get(name), nil } } func FindTokenFromCookie(cookieName string) TokenFinderFunc { return func(r *http.Request) (string, error) { cookie, err := r.Cookie(cookieName) if err != nil && !errors.Is(err, http.ErrNoCookie) { return "", errors.WithStack(err) } if cookie == nil { return "", nil } return cookie.Value, nil } } func FindRawToken(r *http.Request, funcs ...FindTokenOptionFunc) (string, error) { opts := NewFindTokenOptions(funcs...) var rawToken string var err error for _, find := range opts.Finders { rawToken, err = find(r) if err != nil { return "", errors.WithStack(err) } if rawToken == "" { continue } break } if rawToken == "" { return "", errors.WithStack(ErrUnauthenticated) } return rawToken, nil } func FindToken(r *http.Request, getKeySet GetKeySetFunc, funcs ...FindTokenOptionFunc) (jwt.Token, error) { rawToken, err := FindRawToken(r, funcs...) if err != nil { return nil, errors.WithStack(err) } keySet, err := getKeySet() if err != nil { return nil, errors.WithStack(err) } if keySet == nil { return nil, errors.WithStack(ErrNoKeySet) } token, err := Parse([]byte(rawToken), keySet) if err != nil { return nil, errors.WithStack(err) } return token, nil }