feat(storage-server): jwt based authentication
All checks were successful
arcad/edge/pipeline/pr-master This commit looks good
All checks were successful
arcad/edge/pipeline/pr-master This commit looks good
This commit is contained in:
6
pkg/jwtutil/error.go
Normal file
6
pkg/jwtutil/error.go
Normal file
@ -0,0 +1,6 @@
|
||||
package jwtutil
|
||||
|
||||
import "errors"
|
||||
|
||||
var ErrUnauthenticated = errors.New("unauthenticated")
|
||||
var ErrNoKeySet = errors.New("no keyset")
|
71
pkg/jwtutil/io.go
Normal file
71
pkg/jwtutil/io.go
Normal file
@ -0,0 +1,71 @@
|
||||
package jwtutil
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"os"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func LoadOrGenerateKey(path string, defaultKeySize int) (jwk.Key, error) {
|
||||
key, err := LoadKey(path)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
key, err = GenerateKey(defaultKeySize)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
if err := SaveKey(path, key); err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func LoadKey(path string) (jwk.Key, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
key, err := jwk.ParseKey(data, jwk.WithPEM(true))
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func SaveKey(path string, key jwk.Key) error {
|
||||
data, err := jwk.Pem(key)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(path, data, os.FileMode(0600)); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func GenerateKey(keySize int) (jwk.Key, error) {
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, keySize)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
key, err := jwk.FromRaw(rsaKey)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
52
pkg/jwtutil/key.go
Normal file
52
pkg/jwtutil/key.go
Normal file
@ -0,0 +1,52 @@
|
||||
package jwtutil
|
||||
|
||||
import (
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func NewKeySet(keys ...jwk.Key) (jwk.Set, error) {
|
||||
set := jwk.NewSet()
|
||||
|
||||
for _, k := range keys {
|
||||
if err := set.AddKey(k); err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
|
||||
return set, nil
|
||||
}
|
||||
|
||||
func NewSymmetricKey(secret []byte) (jwk.Key, error) {
|
||||
key, err := jwk.FromRaw(secret)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
if err := key.Set(jwk.AlgorithmKey, jwa.HS256); err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func NewSymmetricKeySet(secrets ...[]byte) (jwk.Set, error) {
|
||||
keys := make([]jwk.Key, len(secrets))
|
||||
|
||||
for idx, sec := range secrets {
|
||||
key, err := NewSymmetricKey(sec)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
keys[idx] = key
|
||||
}
|
||||
|
||||
keySet, err := NewKeySet(keys...)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return keySet, nil
|
||||
}
|
123
pkg/jwtutil/request.go
Normal file
123
pkg/jwtutil/request.go
Normal file
@ -0,0 +1,123 @@
|
||||
package jwtutil
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/lestrrat-go/jwx/v2/jws"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type TokenFinderFunc func(r *http.Request) (string, error)
|
||||
|
||||
type FindTokenOptions struct {
|
||||
Finders []TokenFinderFunc
|
||||
}
|
||||
|
||||
type FindTokenOptionFunc func(*FindTokenOptions)
|
||||
|
||||
type GetKeySetFunc func() (jwk.Set, error)
|
||||
|
||||
func WithFinders(finders ...TokenFinderFunc) FindTokenOptionFunc {
|
||||
return func(opts *FindTokenOptions) {
|
||||
opts.Finders = finders
|
||||
}
|
||||
}
|
||||
|
||||
func NewFindTokenOptions(funcs ...FindTokenOptionFunc) *FindTokenOptions {
|
||||
opts := &FindTokenOptions{
|
||||
Finders: []TokenFinderFunc{
|
||||
FindTokenFromAuthorizationHeader,
|
||||
},
|
||||
}
|
||||
|
||||
for _, fn := range funcs {
|
||||
fn(opts)
|
||||
}
|
||||
|
||||
return opts
|
||||
}
|
||||
|
||||
func FindTokenFromAuthorizationHeader(r *http.Request) (string, error) {
|
||||
authorization := r.Header.Get("Authorization")
|
||||
|
||||
// Retrieve token from Authorization header
|
||||
rawToken := strings.TrimPrefix(authorization, "Bearer ")
|
||||
|
||||
return rawToken, nil
|
||||
}
|
||||
|
||||
func FindTokenFromQueryString(name string) TokenFinderFunc {
|
||||
return func(r *http.Request) (string, error) {
|
||||
return r.URL.Query().Get(name), nil
|
||||
}
|
||||
}
|
||||
|
||||
func FindTokenFromCookie(cookieName string) TokenFinderFunc {
|
||||
return func(r *http.Request) (string, error) {
|
||||
cookie, err := r.Cookie(cookieName)
|
||||
if err != nil && !errors.Is(err, http.ErrNoCookie) {
|
||||
return "", errors.WithStack(err)
|
||||
}
|
||||
|
||||
if cookie == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return cookie.Value, nil
|
||||
}
|
||||
}
|
||||
|
||||
func FindRawToken(r *http.Request, funcs ...FindTokenOptionFunc) (string, error) {
|
||||
opts := NewFindTokenOptions(funcs...)
|
||||
|
||||
var rawToken string
|
||||
var err error
|
||||
|
||||
for _, find := range opts.Finders {
|
||||
rawToken, err = find(r)
|
||||
if err != nil {
|
||||
return "", errors.WithStack(err)
|
||||
}
|
||||
|
||||
if rawToken == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
if rawToken == "" {
|
||||
return "", errors.WithStack(ErrUnauthenticated)
|
||||
}
|
||||
|
||||
return rawToken, nil
|
||||
}
|
||||
|
||||
func FindToken(r *http.Request, getKeySet GetKeySetFunc, funcs ...FindTokenOptionFunc) (jwt.Token, error) {
|
||||
rawToken, err := FindRawToken(r, funcs...)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
keySet, err := getKeySet()
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
if keySet == nil {
|
||||
return nil, errors.WithStack(ErrNoKeySet)
|
||||
}
|
||||
|
||||
token, err := jwt.Parse([]byte(rawToken),
|
||||
jwt.WithKeySet(keySet, jws.WithRequireKid(false)),
|
||||
jwt.WithValidate(true),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
35
pkg/jwtutil/token.go
Normal file
35
pkg/jwtutil/token.go
Normal file
@ -0,0 +1,35 @@
|
||||
package jwtutil
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func SignedToken(key jwk.Key, signingAlgorithm jwa.SignatureAlgorithm, claims map[string]any) ([]byte, error) {
|
||||
token := jwt.New()
|
||||
|
||||
if err := token.Set(jwt.NotBeforeKey, time.Now()); err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
for key, value := range claims {
|
||||
if err := token.Set(key, value); err != nil {
|
||||
return nil, errors.Wrapf(err, "could not set claim '%s' with value '%v'", key, value)
|
||||
}
|
||||
}
|
||||
|
||||
if err := token.Set(jwk.AlgorithmKey, signingAlgorithm); err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
rawToken, err := jwt.Sign(token, jwt.WithKey(signingAlgorithm, key))
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return rawToken, nil
|
||||
}
|
Reference in New Issue
Block a user