feat: add basic local login handler in dev cli
This commit is contained in:
35
pkg/module/auth/http/jwt.go
Normal file
35
pkg/module/auth/http/jwt.go
Normal file
@ -0,0 +1,35 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func generateSignedToken(algo jwa.KeyAlgorithm, key jwk.Key, claims map[string]any) ([]byte, error) {
|
||||
token := jwt.New()
|
||||
|
||||
if err := token.Set(jwt.NotBeforeKey, time.Now()); err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
for key, value := range claims {
|
||||
if err := token.Set(key, value); err != nil {
|
||||
return nil, errors.Wrapf(err, "could not set claim '%s' with value '%v'", key, value)
|
||||
}
|
||||
}
|
||||
|
||||
if err := token.Set(jwk.AlgorithmKey, jwa.HS256); err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
rawToken, err := jwt.Sign(token, jwt.WithKey(algo, key))
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return rawToken, nil
|
||||
}
|
27
pkg/module/auth/http/local_account.go
Normal file
27
pkg/module/auth/http/local_account.go
Normal file
@ -0,0 +1,27 @@
|
||||
package http
|
||||
|
||||
import "forge.cadoles.com/arcad/edge/pkg/module/auth/http/passwd"
|
||||
|
||||
type LocalAccount struct {
|
||||
Username string `json:"username"`
|
||||
Algo passwd.Algo `json:"algo"`
|
||||
Password string `json:"password"`
|
||||
Claims map[string]any `json:"claims"`
|
||||
}
|
||||
|
||||
func NewLocalAccount(username, password string, algo passwd.Algo, claims map[string]any) LocalAccount {
|
||||
return LocalAccount{
|
||||
Username: username,
|
||||
Password: password,
|
||||
Algo: algo,
|
||||
Claims: claims,
|
||||
}
|
||||
}
|
||||
|
||||
func toAccountsMap(accounts []LocalAccount) map[string]LocalAccount {
|
||||
accountsMap := make(map[string]LocalAccount)
|
||||
for _, acc := range accounts {
|
||||
accountsMap[acc.Username] = acc
|
||||
}
|
||||
return accountsMap
|
||||
}
|
183
pkg/module/auth/http/local_handler.go
Normal file
183
pkg/module/auth/http/local_handler.go
Normal file
@ -0,0 +1,183 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"html/template"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
_ "embed"
|
||||
|
||||
"forge.cadoles.com/arcad/edge/pkg/module/auth"
|
||||
"forge.cadoles.com/arcad/edge/pkg/module/auth/http/passwd"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/pkg/errors"
|
||||
"gitlab.com/wpetit/goweb/logger"
|
||||
)
|
||||
|
||||
//go:embed templates/login.html.tmpl
|
||||
var rawLoginTemplate string
|
||||
var loginTemplate *template.Template
|
||||
|
||||
var (
|
||||
errNotFound = errors.New("not found")
|
||||
errInvalidPassword = errors.New("invalid password")
|
||||
)
|
||||
|
||||
func init() {
|
||||
loginTemplate = template.Must(template.New("").Parse(rawLoginTemplate))
|
||||
}
|
||||
|
||||
type LocalHandler struct {
|
||||
router chi.Router
|
||||
algo jwa.KeyAlgorithm
|
||||
key jwk.Key
|
||||
accounts map[string]LocalAccount
|
||||
}
|
||||
|
||||
func (h *LocalHandler) initRouter(prefix string) {
|
||||
router := chi.NewRouter()
|
||||
|
||||
router.Route(prefix, func(r chi.Router) {
|
||||
r.Get("/login", h.serveForm)
|
||||
r.Post("/login", h.handleForm)
|
||||
r.Get("/logout", h.handleLogout)
|
||||
})
|
||||
|
||||
h.router = router
|
||||
}
|
||||
|
||||
type loginTemplateData struct {
|
||||
URL string
|
||||
Username string
|
||||
Password string
|
||||
Message string
|
||||
}
|
||||
|
||||
func (h *LocalHandler) serveForm(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
data := loginTemplateData{
|
||||
URL: r.URL.String(),
|
||||
Username: "",
|
||||
Password: "",
|
||||
Message: "",
|
||||
}
|
||||
|
||||
if err := loginTemplate.Execute(w, data); err != nil {
|
||||
logger.Error(ctx, "could not execute login page template", logger.E(errors.WithStack(err)))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *LocalHandler) handleForm(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
if err := r.ParseForm(); err != nil {
|
||||
logger.Error(ctx, "could not parse form", logger.E(errors.WithStack(err)))
|
||||
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
username := r.Form.Get("username")
|
||||
password := r.Form.Get("password")
|
||||
|
||||
data := loginTemplateData{
|
||||
URL: r.URL.String(),
|
||||
Username: username,
|
||||
Password: password,
|
||||
Message: "",
|
||||
}
|
||||
|
||||
account, err := h.authenticate(username, password)
|
||||
if err != nil {
|
||||
if errors.Is(err, errNotFound) || errors.Is(err, errInvalidPassword) {
|
||||
data.Message = "Invalid username or password."
|
||||
|
||||
if err := loginTemplate.Execute(w, data); err != nil {
|
||||
logger.Error(ctx, "could not execute login page template", logger.E(errors.WithStack(err)))
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
logger.Error(ctx, "could not authenticate account", logger.E(errors.WithStack(err)))
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
token, err := generateSignedToken(h.algo, h.key, account.Claims)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "could not generate signed token", logger.E(errors.WithStack(err)))
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
cookie := http.Cookie{
|
||||
Name: auth.CookieName,
|
||||
Value: string(token),
|
||||
HttpOnly: false,
|
||||
Path: "/",
|
||||
}
|
||||
|
||||
http.SetCookie(w, &cookie)
|
||||
|
||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||
}
|
||||
|
||||
func (h *LocalHandler) handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: auth.CookieName,
|
||||
Value: "",
|
||||
HttpOnly: false,
|
||||
Expires: time.Unix(0, 0),
|
||||
Path: "/",
|
||||
})
|
||||
|
||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||
}
|
||||
|
||||
func (h *LocalHandler) authenticate(username, password string) (*LocalAccount, error) {
|
||||
account, exists := h.accounts[username]
|
||||
if !exists {
|
||||
return nil, errors.WithStack(errNotFound)
|
||||
}
|
||||
|
||||
matches, err := passwd.Match(account.Algo, password, account.Password)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
if !matches {
|
||||
return nil, errors.WithStack(errInvalidPassword)
|
||||
}
|
||||
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func NewLocalHandler(algo jwa.KeyAlgorithm, key jwk.Key, funcs ...LocalHandlerOptionFunc) *LocalHandler {
|
||||
opts := defaultLocalHandlerOptions()
|
||||
for _, fn := range funcs {
|
||||
fn(opts)
|
||||
}
|
||||
|
||||
handler := &LocalHandler{
|
||||
algo: algo,
|
||||
key: key,
|
||||
accounts: toAccountsMap(opts.Accounts),
|
||||
}
|
||||
|
||||
handler.initRouter(opts.RoutePrefix)
|
||||
|
||||
return handler
|
||||
}
|
||||
|
||||
// ServeHTTP implements http.Handler.
|
||||
func (h *LocalHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
h.router.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
var _ http.Handler = &LocalHandler{}
|
27
pkg/module/auth/http/options.go
Normal file
27
pkg/module/auth/http/options.go
Normal file
@ -0,0 +1,27 @@
|
||||
package http
|
||||
|
||||
type LocalHandlerOptions struct {
|
||||
RoutePrefix string
|
||||
Accounts []LocalAccount
|
||||
}
|
||||
|
||||
type LocalHandlerOptionFunc func(*LocalHandlerOptions)
|
||||
|
||||
func defaultLocalHandlerOptions() *LocalHandlerOptions {
|
||||
return &LocalHandlerOptions{
|
||||
RoutePrefix: "",
|
||||
Accounts: make([]LocalAccount, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func WithAccounts(accounts ...LocalAccount) LocalHandlerOptionFunc {
|
||||
return func(opts *LocalHandlerOptions) {
|
||||
opts.Accounts = accounts
|
||||
}
|
||||
}
|
||||
|
||||
func WithRoutePrefix(prefix string) LocalHandlerOptionFunc {
|
||||
return func(opts *LocalHandlerOptions) {
|
||||
opts.RoutePrefix = prefix
|
||||
}
|
||||
}
|
136
pkg/module/auth/http/passwd/argon2id/hasher.go
Normal file
136
pkg/module/auth/http/passwd/argon2id/hasher.go
Normal file
@ -0,0 +1,136 @@
|
||||
package argon2id
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"forge.cadoles.com/arcad/edge/pkg/module/auth/http/passwd"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/argon2"
|
||||
)
|
||||
|
||||
const (
|
||||
Algo passwd.Algo = "argon2id"
|
||||
)
|
||||
|
||||
func init() {
|
||||
passwd.Register(Algo, &Hasher{})
|
||||
}
|
||||
|
||||
var (
|
||||
ErrInvalidHash = errors.New("invalid hash")
|
||||
ErrIncompatibleVersion = errors.New("incompatible version")
|
||||
)
|
||||
|
||||
type params struct {
|
||||
memory uint32
|
||||
iterations uint32
|
||||
parallelism uint8
|
||||
saltLength uint32
|
||||
keyLength uint32
|
||||
}
|
||||
|
||||
var defaultParams = params{
|
||||
memory: 64 * 1024,
|
||||
iterations: 3,
|
||||
parallelism: 2,
|
||||
saltLength: 16,
|
||||
keyLength: 32,
|
||||
}
|
||||
|
||||
type Hasher struct{}
|
||||
|
||||
// Hash implements passwd.Hasher
|
||||
func (*Hasher) Hash(plaintext string) (string, error) {
|
||||
salt, err := generateRandomBytes(defaultParams.saltLength)
|
||||
if err != nil {
|
||||
return "", errors.WithStack(err)
|
||||
}
|
||||
|
||||
hash := argon2.IDKey([]byte(plaintext), salt, defaultParams.iterations, defaultParams.memory, defaultParams.parallelism, defaultParams.keyLength)
|
||||
|
||||
// Base64 encode the salt and hashed password.
|
||||
b64Salt := base64.RawStdEncoding.EncodeToString(salt)
|
||||
b64Hash := base64.RawStdEncoding.EncodeToString(hash)
|
||||
|
||||
// Return a string using the standard encoded hash representation.
|
||||
encodedHash := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", argon2.Version, defaultParams.memory, defaultParams.iterations, defaultParams.parallelism, b64Salt, b64Hash)
|
||||
|
||||
return encodedHash, nil
|
||||
}
|
||||
|
||||
// Match implements passwd.Hasher.
|
||||
func (*Hasher) Match(plaintext string, hash string) (bool, error) {
|
||||
matches, err := comparePasswordAndHash(plaintext, hash)
|
||||
if err != nil {
|
||||
return false, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return matches, nil
|
||||
}
|
||||
|
||||
var _ passwd.Hasher = &Hasher{}
|
||||
|
||||
func generateRandomBytes(n uint32) ([]byte, error) {
|
||||
buf := make([]byte, n)
|
||||
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func comparePasswordAndHash(password, encodedHash string) (match bool, err error) {
|
||||
p, salt, hash, err := decodeHash(encodedHash)
|
||||
if err != nil {
|
||||
return false, errors.WithStack(err)
|
||||
}
|
||||
|
||||
otherHash := argon2.IDKey([]byte(password), salt, p.iterations, p.memory, p.parallelism, p.keyLength)
|
||||
|
||||
if subtle.ConstantTimeCompare(hash, otherHash) == 1 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func decodeHash(encodedHash string) (p *params, salt, hash []byte, err error) {
|
||||
vals := strings.Split(encodedHash, "$")
|
||||
if len(vals) != 6 {
|
||||
return nil, nil, nil, ErrInvalidHash
|
||||
}
|
||||
|
||||
var version int
|
||||
_, err = fmt.Sscanf(vals[2], "v=%d", &version)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
if version != argon2.Version {
|
||||
return nil, nil, nil, ErrIncompatibleVersion
|
||||
}
|
||||
|
||||
p = ¶ms{}
|
||||
_, err = fmt.Sscanf(vals[3], "m=%d,t=%d,p=%d", &p.memory, &p.iterations, &p.parallelism)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
salt, err = base64.RawStdEncoding.Strict().DecodeString(vals[4])
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
p.saltLength = uint32(len(salt))
|
||||
|
||||
hash, err = base64.RawStdEncoding.Strict().DecodeString(vals[5])
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
p.keyLength = uint32(len(hash))
|
||||
|
||||
return p, salt, hash, nil
|
||||
}
|
8
pkg/module/auth/http/passwd/hasher.go
Normal file
8
pkg/module/auth/http/passwd/hasher.go
Normal file
@ -0,0 +1,8 @@
|
||||
package passwd
|
||||
|
||||
type Algo string
|
||||
|
||||
type Hasher interface {
|
||||
Hash(plaintext string) (string, error)
|
||||
Match(plaintext string, hash string) (bool, error)
|
||||
}
|
31
pkg/module/auth/http/passwd/plain/hasher.go
Normal file
31
pkg/module/auth/http/passwd/plain/hasher.go
Normal file
@ -0,0 +1,31 @@
|
||||
package plain
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
|
||||
"forge.cadoles.com/arcad/edge/pkg/module/auth/http/passwd"
|
||||
)
|
||||
|
||||
const (
|
||||
Algo passwd.Algo = "plain"
|
||||
)
|
||||
|
||||
func init() {
|
||||
passwd.Register(Algo, &Hasher{})
|
||||
}
|
||||
|
||||
type Hasher struct{}
|
||||
|
||||
// Hash implements passwd.Hasher
|
||||
func (*Hasher) Hash(plaintext string) (string, error) {
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// Match implements passwd.Hasher.
|
||||
func (*Hasher) Match(plaintext string, hash string) (bool, error) {
|
||||
matches := subtle.ConstantTimeCompare([]byte(plaintext), []byte(hash)) == 1
|
||||
|
||||
return matches, nil
|
||||
}
|
||||
|
||||
var _ passwd.Hasher = &Hasher{}
|
87
pkg/module/auth/http/passwd/registry.go
Normal file
87
pkg/module/auth/http/passwd/registry.go
Normal file
@ -0,0 +1,87 @@
|
||||
package passwd
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var ErrAlgoNotFound = errors.New("algo not found")
|
||||
|
||||
type Registry struct {
|
||||
hashers map[Algo]Hasher
|
||||
}
|
||||
|
||||
func (r *Registry) Register(algo Algo, hasher Hasher) {
|
||||
r.hashers[algo] = hasher
|
||||
}
|
||||
|
||||
func (r *Registry) Match(algo Algo, plaintext string, hash string) (bool, error) {
|
||||
hasher, exists := r.hashers[algo]
|
||||
if !exists {
|
||||
return false, errors.WithStack(ErrAlgoNotFound)
|
||||
}
|
||||
|
||||
matches, err := hasher.Match(plaintext, hash)
|
||||
if err != nil {
|
||||
return false, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return matches, nil
|
||||
}
|
||||
|
||||
func (r *Registry) Hash(algo Algo, plaintext string) (string, error) {
|
||||
hasher, exists := r.hashers[algo]
|
||||
if !exists {
|
||||
return "", errors.WithStack(ErrAlgoNotFound)
|
||||
}
|
||||
|
||||
hash, err := hasher.Hash(plaintext)
|
||||
if err != nil {
|
||||
return "", errors.WithStack(err)
|
||||
}
|
||||
|
||||
return hash, nil
|
||||
}
|
||||
|
||||
func (r *Registry) Algorithms() []Algo {
|
||||
algorithms := make([]Algo, 0, len(r.hashers))
|
||||
|
||||
for algo := range r.hashers {
|
||||
algorithms = append(algorithms, algo)
|
||||
}
|
||||
|
||||
return algorithms
|
||||
}
|
||||
|
||||
func NewRegistry() *Registry {
|
||||
return &Registry{
|
||||
hashers: make(map[Algo]Hasher),
|
||||
}
|
||||
}
|
||||
|
||||
var defaultRegistry = NewRegistry()
|
||||
|
||||
func Match(algo Algo, plaintext string, hash string) (bool, error) {
|
||||
matches, err := defaultRegistry.Match(algo, plaintext, hash)
|
||||
if err != nil {
|
||||
return false, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return matches, nil
|
||||
}
|
||||
|
||||
func Hash(algo Algo, plaintext string) (string, error) {
|
||||
hash, err := defaultRegistry.Hash(algo, plaintext)
|
||||
if err != nil {
|
||||
return "", errors.WithStack(err)
|
||||
}
|
||||
|
||||
return hash, nil
|
||||
}
|
||||
|
||||
func Algorithms() []Algo {
|
||||
return defaultRegistry.Algorithms()
|
||||
}
|
||||
|
||||
func Register(algo Algo, hasher Hasher) {
|
||||
defaultRegistry.Register(algo, hasher)
|
||||
}
|
105
pkg/module/auth/http/templates/login.html.tmpl
Normal file
105
pkg/module/auth/http/templates/login.html.tmpl
Normal file
@ -0,0 +1,105 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width,initial-scale=1">
|
||||
<title>Login</title>
|
||||
<style>
|
||||
html {
|
||||
box-sizing: border-box;
|
||||
font-size: 16px;
|
||||
}
|
||||
|
||||
*, *:before, *:after {
|
||||
box-sizing: inherit;
|
||||
}
|
||||
|
||||
body, h1, h2, h3, h4, h5, h6, p, ol, ul {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
font-weight: normal;
|
||||
}
|
||||
|
||||
html, body {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
font-family: Arial, Helvetica, sans-serif;
|
||||
background-color: #f7f7f7;
|
||||
}
|
||||
|
||||
#container {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
height: 100%;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.form-control {
|
||||
margin-bottom: 0.5em;
|
||||
}
|
||||
|
||||
.form-control > label {
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.form-control > input {
|
||||
width: 100%;
|
||||
line-height: 1.4em;
|
||||
border: 1px solid #ccc;
|
||||
border-radius: 3px;
|
||||
font-size: 1.2em;
|
||||
padding: 0 5px;
|
||||
margin: 5px 0;
|
||||
}
|
||||
|
||||
#submit {
|
||||
float: right;
|
||||
background-color: #5e77ff;
|
||||
padding: 5px 10px;
|
||||
border: none;
|
||||
border-radius: 5px;
|
||||
color: white;
|
||||
font-size: 1em;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
#submit:hover {
|
||||
background-color: hsl(231deg 100% 71%);
|
||||
}
|
||||
|
||||
#login {
|
||||
padding: 1.5em 1em;
|
||||
border: 1px solid #e0e0e0;
|
||||
background-color: white;
|
||||
border-radius: 5px;
|
||||
box-shadow: 2px 2px #cccccc1c;
|
||||
color: #333333 !important;
|
||||
}
|
||||
|
||||
#message {
|
||||
margin-bottom: 10px;
|
||||
color: red;
|
||||
text-shadow: 1px 1px #fff0f0;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div id="container">
|
||||
<p id="message">{{ .Message }}</p>
|
||||
<div id="login">
|
||||
<form method="post" action="{{ .URL }}">
|
||||
<div class="form-control">
|
||||
<label for="username">Username</label>
|
||||
<input type="text" id="username" name="username" value="{{ .Username }}" required />
|
||||
</div>
|
||||
<div class="form-control">
|
||||
<label for="password">Password</label>
|
||||
<input type="password" id="password" name="password" value="{{ .Password }}" required />
|
||||
</div>
|
||||
<input id="submit" type="submit" value="Login" />
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
@ -5,14 +5,22 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/lestrrat-go/jwx/v2/jws"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func WithJWT(keyFunc jwt.Keyfunc) OptionFunc {
|
||||
const (
|
||||
CookieName string = "edge-auth"
|
||||
)
|
||||
|
||||
type GetKeySetFunc func() (jwk.Set, error)
|
||||
|
||||
func WithJWT(getKeySet GetKeySetFunc) OptionFunc {
|
||||
return func(o *Option) {
|
||||
o.GetClaim = func(ctx context.Context, r *http.Request, claimName string) (string, error) {
|
||||
claim, err := getClaim[string](r, claimName, keyFunc)
|
||||
claim, err := getClaim[string](r, claimName, getKeySet)
|
||||
if err != nil {
|
||||
return "", errors.WithStack(err)
|
||||
}
|
||||
@ -22,28 +30,59 @@ func WithJWT(keyFunc jwt.Keyfunc) OptionFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func getClaim[T any](r *http.Request, claimAttr string, keyFunc jwt.Keyfunc) (T, error) {
|
||||
rawToken := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
|
||||
func FindToken(r *http.Request, getKeySet GetKeySetFunc) (jwt.Token, error) {
|
||||
authorization := r.Header.Get("Authorization")
|
||||
|
||||
// Retrieve token from Authorization header
|
||||
rawToken := strings.TrimPrefix(authorization, "Bearer ")
|
||||
|
||||
// Retrieve token from ?edge-auth=<value>
|
||||
if rawToken == "" {
|
||||
rawToken = r.URL.Query().Get("token")
|
||||
rawToken = r.URL.Query().Get(CookieName)
|
||||
}
|
||||
|
||||
if rawToken == "" {
|
||||
return *new(T), errors.WithStack(ErrUnauthenticated)
|
||||
cookie, err := r.Cookie(CookieName)
|
||||
if err != nil && !errors.Is(err, http.ErrNoCookie) {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
if cookie != nil {
|
||||
rawToken = cookie.Value
|
||||
}
|
||||
}
|
||||
|
||||
token, err := jwt.Parse(rawToken, keyFunc)
|
||||
if rawToken == "" {
|
||||
return nil, errors.WithStack(ErrUnauthenticated)
|
||||
}
|
||||
|
||||
keySet, err := getKeySet()
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
token, err := jwt.Parse([]byte(rawToken),
|
||||
jwt.WithKeySet(keySet, jws.WithRequireKid(false)),
|
||||
jwt.WithValidate(true),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func getClaim[T any](r *http.Request, claimAttr string, getKeySet GetKeySetFunc) (T, error) {
|
||||
token, err := FindToken(r, getKeySet)
|
||||
if err != nil {
|
||||
return *new(T), errors.WithStack(err)
|
||||
}
|
||||
|
||||
if !token.Valid {
|
||||
return *new(T), errors.Errorf("invalid jwt token: '%v'", token.Raw)
|
||||
}
|
||||
ctx := r.Context()
|
||||
|
||||
mapClaims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return *new(T), errors.Errorf("unexpected claims type '%T'", token.Claims)
|
||||
mapClaims, err := token.AsMap(ctx)
|
||||
if err != nil {
|
||||
return *new(T), errors.WithStack(err)
|
||||
}
|
||||
|
||||
rawClaim, exists := mapClaims[claimAttr]
|
||||
|
@ -2,7 +2,6 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"testing"
|
||||
@ -12,7 +11,9 @@ import (
|
||||
"forge.cadoles.com/arcad/edge/pkg/app"
|
||||
edgeHTTP "forge.cadoles.com/arcad/edge/pkg/http"
|
||||
"forge.cadoles.com/arcad/edge/pkg/module"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/pkg/errors"
|
||||
"gitlab.com/wpetit/goweb/logger"
|
||||
)
|
||||
@ -22,12 +23,12 @@ func TestAuthModule(t *testing.T) {
|
||||
|
||||
logger.SetLevel(slog.LevelDebug)
|
||||
|
||||
keyFunc, secret := getKeyFunc()
|
||||
key := getDummyKey()
|
||||
|
||||
server := app.NewServer(
|
||||
module.ConsoleModuleFactory(),
|
||||
ModuleFactory(
|
||||
WithJWT(keyFunc),
|
||||
WithJWT(getDummyKeySet(key)),
|
||||
),
|
||||
)
|
||||
|
||||
@ -51,17 +52,22 @@ func TestAuthModule(t *testing.T) {
|
||||
t.Fatalf("%+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"sub": "jdoe",
|
||||
"nbf": time.Now().UTC().Unix(),
|
||||
})
|
||||
token := jwt.New()
|
||||
|
||||
rawToken, err := token.SignedString(secret)
|
||||
if err := token.Set(jwt.SubjectKey, "jdoe"); err != nil {
|
||||
t.Fatalf("%+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
if err := token.Set(jwt.NotBeforeKey, time.Now()); err != nil {
|
||||
t.Fatalf("%+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
rawToken, err := jwt.Sign(token, jwt.WithKey(jwa.HS256, key))
|
||||
if err != nil {
|
||||
t.Fatalf("%+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
req.Header.Add("Authorization", "Bearer "+rawToken)
|
||||
req.Header.Add("Authorization", "Bearer "+string(rawToken))
|
||||
|
||||
ctx := context.WithValue(context.Background(), edgeHTTP.ContextKeyOriginRequest, req)
|
||||
|
||||
@ -75,11 +81,11 @@ func TestAuthAnonymousModule(t *testing.T) {
|
||||
|
||||
logger.SetLevel(slog.LevelDebug)
|
||||
|
||||
keyFunc, _ := getKeyFunc()
|
||||
key := getDummyKey()
|
||||
|
||||
server := app.NewServer(
|
||||
module.ConsoleModuleFactory(),
|
||||
ModuleFactory(WithJWT(keyFunc)),
|
||||
ModuleFactory(WithJWT(getDummyKeySet(key))),
|
||||
)
|
||||
|
||||
data, err := ioutil.ReadFile("testdata/auth_anonymous.js")
|
||||
@ -109,16 +115,29 @@ func TestAuthAnonymousModule(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func getKeyFunc() (jwt.Keyfunc, []byte) {
|
||||
func getDummyKey() jwk.Key {
|
||||
secret := []byte("not_so_secret")
|
||||
|
||||
keyFunc := func(t *jwt.Token) (interface{}, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("Unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
|
||||
return secret, nil
|
||||
key, err := jwk.FromRaw(secret)
|
||||
if err != nil {
|
||||
panic(errors.WithStack(err))
|
||||
}
|
||||
|
||||
return keyFunc, secret
|
||||
if err := key.Set(jwk.AlgorithmKey, jwa.HS256); err != nil {
|
||||
panic(errors.WithStack(err))
|
||||
}
|
||||
|
||||
return key
|
||||
}
|
||||
|
||||
func getDummyKeySet(key jwk.Key) GetKeySetFunc {
|
||||
return func() (jwk.Set, error) {
|
||||
set := jwk.NewSet()
|
||||
|
||||
if err := set.AddKey(key); err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return set, nil
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user