61 lines
1.3 KiB
Go
61 lines
1.3 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"github.com/golang-jwt/jwt"
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
func WithJWT(keyFunc jwt.Keyfunc) OptionFunc {
|
|
return func(o *Option) {
|
|
o.GetClaim = func(ctx context.Context, r *http.Request, claimName string) (string, error) {
|
|
claim, err := getClaim[string](r, claimName, keyFunc)
|
|
if err != nil {
|
|
return "", errors.WithStack(err)
|
|
}
|
|
|
|
return claim, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
func getClaim[T any](r *http.Request, claimAttr string, keyFunc jwt.Keyfunc) (T, error) {
|
|
rawToken := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
|
|
if rawToken == "" {
|
|
rawToken = r.URL.Query().Get("token")
|
|
}
|
|
|
|
if rawToken == "" {
|
|
return *new(T), errors.WithStack(ErrUnauthenticated)
|
|
}
|
|
|
|
token, err := jwt.Parse(rawToken, keyFunc)
|
|
if err != nil {
|
|
return *new(T), errors.WithStack(err)
|
|
}
|
|
|
|
if !token.Valid {
|
|
return *new(T), errors.Errorf("invalid jwt token: '%v'", token.Raw)
|
|
}
|
|
|
|
mapClaims, ok := token.Claims.(jwt.MapClaims)
|
|
if !ok {
|
|
return *new(T), errors.Errorf("unexpected claims type '%T'", token.Claims)
|
|
}
|
|
|
|
rawClaim, exists := mapClaims[claimAttr]
|
|
if !exists {
|
|
return *new(T), errors.WithStack(ErrClaimNotFound)
|
|
}
|
|
|
|
claim, ok := rawClaim.(T)
|
|
if !ok {
|
|
return *new(T), errors.Errorf("unexpected claim '%s' to be of type '%T', got '%T'", claimAttr, new(T), rawClaim)
|
|
}
|
|
|
|
return claim, nil
|
|
}
|