2023-02-24 14:40:28 +01:00
|
|
|
package auth
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"net/http"
|
|
|
|
"strings"
|
|
|
|
|
2023-03-20 16:40:08 +01:00
|
|
|
"github.com/lestrrat-go/jwx/v2/jwk"
|
|
|
|
"github.com/lestrrat-go/jwx/v2/jws"
|
|
|
|
"github.com/lestrrat-go/jwx/v2/jwt"
|
2023-02-24 14:40:28 +01:00
|
|
|
"github.com/pkg/errors"
|
|
|
|
)
|
|
|
|
|
2023-03-20 16:40:08 +01:00
|
|
|
// CookieName is the key under which FindToken looks for the authentication
// token, both as a URL query parameter and as a cookie name.
const CookieName string = "edge-auth"
|
|
|
|
|
|
|
|
// GetKeySetFunc supplies the JSON Web Key set that FindToken uses to verify a
// token. FindToken invokes it each time it has found a candidate token, and a
// non-nil error aborts the lookup.
type GetKeySetFunc func() (keys jwk.Set, err error)
|
|
|
|
|
|
|
|
func WithJWT(getKeySet GetKeySetFunc) OptionFunc {
|
2023-02-24 14:40:28 +01:00
|
|
|
return func(o *Option) {
|
|
|
|
o.GetClaim = func(ctx context.Context, r *http.Request, claimName string) (string, error) {
|
2023-03-20 16:40:08 +01:00
|
|
|
claim, err := getClaim[string](r, claimName, getKeySet)
|
2023-02-24 14:40:28 +01:00
|
|
|
if err != nil {
|
|
|
|
return "", errors.WithStack(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return claim, nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-03-20 16:40:08 +01:00
|
|
|
func FindToken(r *http.Request, getKeySet GetKeySetFunc) (jwt.Token, error) {
|
|
|
|
authorization := r.Header.Get("Authorization")
|
|
|
|
|
|
|
|
// Retrieve token from Authorization header
|
|
|
|
rawToken := strings.TrimPrefix(authorization, "Bearer ")
|
|
|
|
|
|
|
|
// Retrieve token from ?edge-auth=<value>
|
|
|
|
if rawToken == "" {
|
|
|
|
rawToken = r.URL.Query().Get(CookieName)
|
|
|
|
}
|
|
|
|
|
2023-02-24 14:40:28 +01:00
|
|
|
if rawToken == "" {
|
2023-03-20 16:40:08 +01:00
|
|
|
cookie, err := r.Cookie(CookieName)
|
|
|
|
if err != nil && !errors.Is(err, http.ErrNoCookie) {
|
|
|
|
return nil, errors.WithStack(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if cookie != nil {
|
|
|
|
rawToken = cookie.Value
|
|
|
|
}
|
2023-02-24 14:40:28 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
if rawToken == "" {
|
2023-03-20 16:40:08 +01:00
|
|
|
return nil, errors.WithStack(ErrUnauthenticated)
|
2023-02-24 14:40:28 +01:00
|
|
|
}
|
|
|
|
|
2023-03-20 16:40:08 +01:00
|
|
|
keySet, err := getKeySet()
|
2023-02-24 14:40:28 +01:00
|
|
|
if err != nil {
|
2023-03-20 16:40:08 +01:00
|
|
|
return nil, errors.WithStack(err)
|
2023-02-24 14:40:28 +01:00
|
|
|
}
|
|
|
|
|
2023-03-20 16:40:08 +01:00
|
|
|
token, err := jwt.Parse([]byte(rawToken),
|
|
|
|
jwt.WithKeySet(keySet, jws.WithRequireKid(false)),
|
|
|
|
jwt.WithValidate(true),
|
|
|
|
)
|
|
|
|
if err != nil {
|
|
|
|
return nil, errors.WithStack(err)
|
2023-02-24 14:40:28 +01:00
|
|
|
}
|
|
|
|
|
2023-03-20 16:40:08 +01:00
|
|
|
return token, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func getClaim[T any](r *http.Request, claimAttr string, getKeySet GetKeySetFunc) (T, error) {
|
|
|
|
token, err := FindToken(r, getKeySet)
|
|
|
|
if err != nil {
|
|
|
|
return *new(T), errors.WithStack(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
ctx := r.Context()
|
|
|
|
|
|
|
|
mapClaims, err := token.AsMap(ctx)
|
|
|
|
if err != nil {
|
|
|
|
return *new(T), errors.WithStack(err)
|
2023-02-24 14:40:28 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
rawClaim, exists := mapClaims[claimAttr]
|
|
|
|
if !exists {
|
|
|
|
return *new(T), errors.WithStack(ErrClaimNotFound)
|
|
|
|
}
|
|
|
|
|
|
|
|
claim, ok := rawClaim.(T)
|
|
|
|
if !ok {
|
|
|
|
return *new(T), errors.Errorf("unexpected claim '%s' to be of type '%T', got '%T'", claimAttr, new(T), rawClaim)
|
|
|
|
}
|
|
|
|
|
|
|
|
return claim, nil
|
|
|
|
}
|