120 lines
2.3 KiB
Go
120 lines
2.3 KiB
Go
package jwtutil
|
|
|
|
import (
|
|
"net/http"
|
|
"strings"
|
|
|
|
"github.com/lestrrat-go/jwx/v2/jwk"
|
|
"github.com/lestrrat-go/jwx/v2/jwt"
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
type (
	// TokenFinderFunc extracts a raw token string from an HTTP request.
	// FindRawToken treats an empty result as "not found here" and moves on to
	// the next finder, while a non-nil error aborts the whole search.
	TokenFinderFunc func(req *http.Request) (string, error)

	// FindTokenOptions configures how FindRawToken looks for a token.
	FindTokenOptions struct {
		// Finders are consulted in order; the first non-empty result wins.
		Finders []TokenFinderFunc
	}

	// FindTokenOptionFunc mutates a FindTokenOptions in place; see
	// NewFindTokenOptions for how these are applied.
	FindTokenOptionFunc func(*FindTokenOptions)
)
|
|
|
|
// GetKeySetFunc supplies the JWK set that FindToken passes to Parse.
// FindToken propagates any returned error and reports a nil set as ErrNoKeySet.
type GetKeySetFunc func() (set jwk.Set, err error)
|
|
|
|
func WithFinders(finders ...TokenFinderFunc) FindTokenOptionFunc {
|
|
return func(opts *FindTokenOptions) {
|
|
opts.Finders = finders
|
|
}
|
|
}
|
|
|
|
func NewFindTokenOptions(funcs ...FindTokenOptionFunc) *FindTokenOptions {
|
|
opts := &FindTokenOptions{
|
|
Finders: []TokenFinderFunc{
|
|
FindTokenFromAuthorizationHeader,
|
|
},
|
|
}
|
|
|
|
for _, fn := range funcs {
|
|
fn(opts)
|
|
}
|
|
|
|
return opts
|
|
}
|
|
|
|
// FindTokenFromAuthorizationHeader extracts a bearer token from the request's
// Authorization header. It never returns a non-nil error.
//
// The "Bearer" scheme is matched case-insensitively, because HTTP
// authentication schemes are case-insensitive. Surrounding whitespace is
// stripped from the token.
//
// A header that uses a different scheme (for example "Basic ...") yields an
// empty string. FindRawToken then moves on to its next finder instead of
// treating foreign credentials as a token. For backward compatibility, a
// header value with no scheme at all (no space) is returned unchanged as the
// token.
func FindTokenFromAuthorizationHeader(r *http.Request) (string, error) {
	authorization := strings.TrimSpace(r.Header.Get("Authorization"))
	if authorization == "" {
		return "", nil
	}

	sep := strings.IndexByte(authorization, ' ')
	if sep < 0 {
		// A lone "Bearer" carries no credentials.
		if strings.EqualFold(authorization, "Bearer") {
			return "", nil
		}
		// Legacy: a bare token sent without any scheme prefix.
		return authorization, nil
	}

	scheme, credentials := authorization[:sep], authorization[sep+1:]
	if !strings.EqualFold(scheme, "Bearer") {
		// Some other auth scheme; not a token this finder understands.
		return "", nil
	}

	return strings.TrimSpace(credentials), nil
}
|
|
|
|
func FindTokenFromQueryString(name string) TokenFinderFunc {
|
|
return func(r *http.Request) (string, error) {
|
|
return r.URL.Query().Get(name), nil
|
|
}
|
|
}
|
|
|
|
func FindTokenFromCookie(cookieName string) TokenFinderFunc {
|
|
return func(r *http.Request) (string, error) {
|
|
cookie, err := r.Cookie(cookieName)
|
|
if err != nil && !errors.Is(err, http.ErrNoCookie) {
|
|
return "", errors.WithStack(err)
|
|
}
|
|
|
|
if cookie == nil {
|
|
return "", nil
|
|
}
|
|
|
|
return cookie.Value, nil
|
|
}
|
|
}
|
|
|
|
func FindRawToken(r *http.Request, funcs ...FindTokenOptionFunc) (string, error) {
|
|
opts := NewFindTokenOptions(funcs...)
|
|
|
|
var rawToken string
|
|
var err error
|
|
|
|
for _, find := range opts.Finders {
|
|
rawToken, err = find(r)
|
|
if err != nil {
|
|
return "", errors.WithStack(err)
|
|
}
|
|
|
|
if rawToken == "" {
|
|
continue
|
|
}
|
|
|
|
break
|
|
}
|
|
|
|
if rawToken == "" {
|
|
return "", errors.WithStack(ErrUnauthenticated)
|
|
}
|
|
|
|
return rawToken, nil
|
|
}
|
|
|
|
func FindToken(r *http.Request, getKeySet GetKeySetFunc, funcs ...FindTokenOptionFunc) (jwt.Token, error) {
|
|
rawToken, err := FindRawToken(r, funcs...)
|
|
if err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
keySet, err := getKeySet()
|
|
if err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
if keySet == nil {
|
|
return nil, errors.WithStack(ErrNoKeySet)
|
|
}
|
|
|
|
token, err := Parse([]byte(rawToken), keySet)
|
|
if err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
return token, nil
|
|
}
|