feat: add basic local login handler in dev cli

This commit is contained in:
2023-03-20 16:40:08 +01:00
parent fd12d2ba42
commit 1f4f795d43
23 changed files with 1060 additions and 145 deletions

View File

@ -0,0 +1,35 @@
package http
import (
"time"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/pkg/errors"
)
func generateSignedToken(algo jwa.KeyAlgorithm, key jwk.Key, claims map[string]any) ([]byte, error) {
token := jwt.New()
if err := token.Set(jwt.NotBeforeKey, time.Now()); err != nil {
return nil, errors.WithStack(err)
}
for key, value := range claims {
if err := token.Set(key, value); err != nil {
return nil, errors.Wrapf(err, "could not set claim '%s' with value '%v'", key, value)
}
}
if err := token.Set(jwk.AlgorithmKey, jwa.HS256); err != nil {
return nil, errors.WithStack(err)
}
rawToken, err := jwt.Sign(token, jwt.WithKey(algo, key))
if err != nil {
return nil, errors.WithStack(err)
}
return rawToken, nil
}

View File

@ -0,0 +1,27 @@
package http
import "forge.cadoles.com/arcad/edge/pkg/module/auth/http/passwd"
type LocalAccount struct {
Username string `json:"username"`
Algo passwd.Algo `json:"algo"`
Password string `json:"password"`
Claims map[string]any `json:"claims"`
}
func NewLocalAccount(username, password string, algo passwd.Algo, claims map[string]any) LocalAccount {
return LocalAccount{
Username: username,
Password: password,
Algo: algo,
Claims: claims,
}
}
func toAccountsMap(accounts []LocalAccount) map[string]LocalAccount {
accountsMap := make(map[string]LocalAccount)
for _, acc := range accounts {
accountsMap[acc.Username] = acc
}
return accountsMap
}

View File

@ -0,0 +1,183 @@
package http
import (
"html/template"
"net/http"
"time"
_ "embed"
"forge.cadoles.com/arcad/edge/pkg/module/auth"
"forge.cadoles.com/arcad/edge/pkg/module/auth/http/passwd"
"github.com/go-chi/chi/v5"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger"
)
//go:embed templates/login.html.tmpl
var rawLoginTemplate string
var loginTemplate *template.Template
var (
errNotFound = errors.New("not found")
errInvalidPassword = errors.New("invalid password")
)
func init() {
loginTemplate = template.Must(template.New("").Parse(rawLoginTemplate))
}
type LocalHandler struct {
router chi.Router
algo jwa.KeyAlgorithm
key jwk.Key
accounts map[string]LocalAccount
}
func (h *LocalHandler) initRouter(prefix string) {
router := chi.NewRouter()
router.Route(prefix, func(r chi.Router) {
r.Get("/login", h.serveForm)
r.Post("/login", h.handleForm)
r.Get("/logout", h.handleLogout)
})
h.router = router
}
type loginTemplateData struct {
URL string
Username string
Password string
Message string
}
func (h *LocalHandler) serveForm(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
data := loginTemplateData{
URL: r.URL.String(),
Username: "",
Password: "",
Message: "",
}
if err := loginTemplate.Execute(w, data); err != nil {
logger.Error(ctx, "could not execute login page template", logger.E(errors.WithStack(err)))
}
}
func (h *LocalHandler) handleForm(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := r.ParseForm(); err != nil {
logger.Error(ctx, "could not parse form", logger.E(errors.WithStack(err)))
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
username := r.Form.Get("username")
password := r.Form.Get("password")
data := loginTemplateData{
URL: r.URL.String(),
Username: username,
Password: password,
Message: "",
}
account, err := h.authenticate(username, password)
if err != nil {
if errors.Is(err, errNotFound) || errors.Is(err, errInvalidPassword) {
data.Message = "Invalid username or password."
if err := loginTemplate.Execute(w, data); err != nil {
logger.Error(ctx, "could not execute login page template", logger.E(errors.WithStack(err)))
}
return
}
logger.Error(ctx, "could not authenticate account", logger.E(errors.WithStack(err)))
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
token, err := generateSignedToken(h.algo, h.key, account.Claims)
if err != nil {
logger.Error(ctx, "could not generate signed token", logger.E(errors.WithStack(err)))
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
cookie := http.Cookie{
Name: auth.CookieName,
Value: string(token),
HttpOnly: false,
Path: "/",
}
http.SetCookie(w, &cookie)
http.Redirect(w, r, "/", http.StatusSeeOther)
}
func (h *LocalHandler) handleLogout(w http.ResponseWriter, r *http.Request) {
http.SetCookie(w, &http.Cookie{
Name: auth.CookieName,
Value: "",
HttpOnly: false,
Expires: time.Unix(0, 0),
Path: "/",
})
http.Redirect(w, r, "/", http.StatusSeeOther)
}
func (h *LocalHandler) authenticate(username, password string) (*LocalAccount, error) {
account, exists := h.accounts[username]
if !exists {
return nil, errors.WithStack(errNotFound)
}
matches, err := passwd.Match(account.Algo, password, account.Password)
if err != nil {
return nil, errors.WithStack(err)
}
if !matches {
return nil, errors.WithStack(errInvalidPassword)
}
return &account, nil
}
func NewLocalHandler(algo jwa.KeyAlgorithm, key jwk.Key, funcs ...LocalHandlerOptionFunc) *LocalHandler {
opts := defaultLocalHandlerOptions()
for _, fn := range funcs {
fn(opts)
}
handler := &LocalHandler{
algo: algo,
key: key,
accounts: toAccountsMap(opts.Accounts),
}
handler.initRouter(opts.RoutePrefix)
return handler
}
// ServeHTTP implements http.Handler.
func (h *LocalHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.router.ServeHTTP(w, r)
}
var _ http.Handler = &LocalHandler{}

View File

@ -0,0 +1,27 @@
package http
type LocalHandlerOptions struct {
RoutePrefix string
Accounts []LocalAccount
}
type LocalHandlerOptionFunc func(*LocalHandlerOptions)
func defaultLocalHandlerOptions() *LocalHandlerOptions {
return &LocalHandlerOptions{
RoutePrefix: "",
Accounts: make([]LocalAccount, 0),
}
}
func WithAccounts(accounts ...LocalAccount) LocalHandlerOptionFunc {
return func(opts *LocalHandlerOptions) {
opts.Accounts = accounts
}
}
func WithRoutePrefix(prefix string) LocalHandlerOptionFunc {
return func(opts *LocalHandlerOptions) {
opts.RoutePrefix = prefix
}
}

View File

@ -0,0 +1,136 @@
package argon2id
import (
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"fmt"
"strings"
"forge.cadoles.com/arcad/edge/pkg/module/auth/http/passwd"
"github.com/pkg/errors"
"golang.org/x/crypto/argon2"
)
const (
Algo passwd.Algo = "argon2id"
)
func init() {
passwd.Register(Algo, &Hasher{})
}
var (
ErrInvalidHash = errors.New("invalid hash")
ErrIncompatibleVersion = errors.New("incompatible version")
)
type params struct {
memory uint32
iterations uint32
parallelism uint8
saltLength uint32
keyLength uint32
}
var defaultParams = params{
memory: 64 * 1024,
iterations: 3,
parallelism: 2,
saltLength: 16,
keyLength: 32,
}
type Hasher struct{}
// Hash implements passwd.Hasher
func (*Hasher) Hash(plaintext string) (string, error) {
salt, err := generateRandomBytes(defaultParams.saltLength)
if err != nil {
return "", errors.WithStack(err)
}
hash := argon2.IDKey([]byte(plaintext), salt, defaultParams.iterations, defaultParams.memory, defaultParams.parallelism, defaultParams.keyLength)
// Base64 encode the salt and hashed password.
b64Salt := base64.RawStdEncoding.EncodeToString(salt)
b64Hash := base64.RawStdEncoding.EncodeToString(hash)
// Return a string using the standard encoded hash representation.
encodedHash := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", argon2.Version, defaultParams.memory, defaultParams.iterations, defaultParams.parallelism, b64Salt, b64Hash)
return encodedHash, nil
}
// Match implements passwd.Hasher.
func (*Hasher) Match(plaintext string, hash string) (bool, error) {
matches, err := comparePasswordAndHash(plaintext, hash)
if err != nil {
return false, errors.WithStack(err)
}
return matches, nil
}
var _ passwd.Hasher = &Hasher{}
func generateRandomBytes(n uint32) ([]byte, error) {
buf := make([]byte, n)
if _, err := rand.Read(buf); err != nil {
return nil, errors.WithStack(err)
}
return buf, nil
}
func comparePasswordAndHash(password, encodedHash string) (match bool, err error) {
p, salt, hash, err := decodeHash(encodedHash)
if err != nil {
return false, errors.WithStack(err)
}
otherHash := argon2.IDKey([]byte(password), salt, p.iterations, p.memory, p.parallelism, p.keyLength)
if subtle.ConstantTimeCompare(hash, otherHash) == 1 {
return true, nil
}
return false, nil
}
func decodeHash(encodedHash string) (p *params, salt, hash []byte, err error) {
vals := strings.Split(encodedHash, "$")
if len(vals) != 6 {
return nil, nil, nil, ErrInvalidHash
}
var version int
_, err = fmt.Sscanf(vals[2], "v=%d", &version)
if err != nil {
return nil, nil, nil, err
}
if version != argon2.Version {
return nil, nil, nil, ErrIncompatibleVersion
}
p = &params{}
_, err = fmt.Sscanf(vals[3], "m=%d,t=%d,p=%d", &p.memory, &p.iterations, &p.parallelism)
if err != nil {
return nil, nil, nil, err
}
salt, err = base64.RawStdEncoding.Strict().DecodeString(vals[4])
if err != nil {
return nil, nil, nil, err
}
p.saltLength = uint32(len(salt))
hash, err = base64.RawStdEncoding.Strict().DecodeString(vals[5])
if err != nil {
return nil, nil, nil, err
}
p.keyLength = uint32(len(hash))
return p, salt, hash, nil
}

View File

@ -0,0 +1,8 @@
package passwd
type Algo string
type Hasher interface {
Hash(plaintext string) (string, error)
Match(plaintext string, hash string) (bool, error)
}

View File

@ -0,0 +1,31 @@
package plain
import (
"crypto/subtle"
"forge.cadoles.com/arcad/edge/pkg/module/auth/http/passwd"
)
const (
Algo passwd.Algo = "plain"
)
func init() {
passwd.Register(Algo, &Hasher{})
}
type Hasher struct{}
// Hash implements passwd.Hasher
func (*Hasher) Hash(plaintext string) (string, error) {
return plaintext, nil
}
// Match implements passwd.Hasher.
func (*Hasher) Match(plaintext string, hash string) (bool, error) {
matches := subtle.ConstantTimeCompare([]byte(plaintext), []byte(hash)) == 1
return matches, nil
}
var _ passwd.Hasher = &Hasher{}

View File

@ -0,0 +1,87 @@
package passwd
import (
"github.com/pkg/errors"
)
var ErrAlgoNotFound = errors.New("algo not found")
type Registry struct {
hashers map[Algo]Hasher
}
func (r *Registry) Register(algo Algo, hasher Hasher) {
r.hashers[algo] = hasher
}
func (r *Registry) Match(algo Algo, plaintext string, hash string) (bool, error) {
hasher, exists := r.hashers[algo]
if !exists {
return false, errors.WithStack(ErrAlgoNotFound)
}
matches, err := hasher.Match(plaintext, hash)
if err != nil {
return false, errors.WithStack(err)
}
return matches, nil
}
func (r *Registry) Hash(algo Algo, plaintext string) (string, error) {
hasher, exists := r.hashers[algo]
if !exists {
return "", errors.WithStack(ErrAlgoNotFound)
}
hash, err := hasher.Hash(plaintext)
if err != nil {
return "", errors.WithStack(err)
}
return hash, nil
}
func (r *Registry) Algorithms() []Algo {
algorithms := make([]Algo, 0, len(r.hashers))
for algo := range r.hashers {
algorithms = append(algorithms, algo)
}
return algorithms
}
func NewRegistry() *Registry {
return &Registry{
hashers: make(map[Algo]Hasher),
}
}
var defaultRegistry = NewRegistry()
func Match(algo Algo, plaintext string, hash string) (bool, error) {
matches, err := defaultRegistry.Match(algo, plaintext, hash)
if err != nil {
return false, errors.WithStack(err)
}
return matches, nil
}
func Hash(algo Algo, plaintext string) (string, error) {
hash, err := defaultRegistry.Hash(algo, plaintext)
if err != nil {
return "", errors.WithStack(err)
}
return hash, nil
}
func Algorithms() []Algo {
return defaultRegistry.Algorithms()
}
func Register(algo Algo, hasher Hasher) {
defaultRegistry.Register(algo, hasher)
}

View File

@ -0,0 +1,105 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width,initial-scale=1">
<title>Login</title>
<style>
html {
box-sizing: border-box;
font-size: 16px;
}
*, *:before, *:after {
box-sizing: inherit;
}
body, h1, h2, h3, h4, h5, h6, p, ol, ul {
margin: 0;
padding: 0;
font-weight: normal;
}
html, body {
width: 100%;
height: 100%;
font-family: Arial, Helvetica, sans-serif;
background-color: #f7f7f7;
}
#container {
display: flex;
align-items: center;
justify-content: center;
height: 100%;
flex-direction: column;
}
.form-control {
margin-bottom: 0.5em;
}
.form-control > label {
font-weight: bold;
}
.form-control > input {
width: 100%;
line-height: 1.4em;
border: 1px solid #ccc;
border-radius: 3px;
font-size: 1.2em;
padding: 0 5px;
margin: 5px 0;
}
#submit {
float: right;
background-color: #5e77ff;
padding: 5px 10px;
border: none;
border-radius: 5px;
color: white;
font-size: 1em;
cursor: pointer;
}
#submit:hover {
background-color: hsl(231deg 100% 71%);
}
#login {
padding: 1.5em 1em;
border: 1px solid #e0e0e0;
background-color: white;
border-radius: 5px;
box-shadow: 2px 2px #cccccc1c;
color: #333333 !important;
}
#message {
margin-bottom: 10px;
color: red;
text-shadow: 1px 1px #fff0f0;
}
</style>
</head>
<body>
<div id="container">
<p id="message">{{ .Message }}</p>
<div id="login">
<form method="post" action="{{ .URL }}">
<div class="form-control">
<label for="username">Username</label>
<input type="text" id="username" name="username" value="{{ .Username }}" required />
</div>
<div class="form-control">
<label for="password">Password</label>
<input type="password" id="password" name="password" value="{{ .Password }}" required />
</div>
<input id="submit" type="submit" value="Login" />
</form>
</div>
</div>
</body>
</html>

View File

@ -5,14 +5,22 @@ import (
"net/http"
"strings"
"github.com/golang-jwt/jwt"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jws"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/pkg/errors"
)
func WithJWT(keyFunc jwt.Keyfunc) OptionFunc {
const (
CookieName string = "edge-auth"
)
type GetKeySetFunc func() (jwk.Set, error)
func WithJWT(getKeySet GetKeySetFunc) OptionFunc {
return func(o *Option) {
o.GetClaim = func(ctx context.Context, r *http.Request, claimName string) (string, error) {
claim, err := getClaim[string](r, claimName, keyFunc)
claim, err := getClaim[string](r, claimName, getKeySet)
if err != nil {
return "", errors.WithStack(err)
}
@ -22,28 +30,59 @@ func WithJWT(keyFunc jwt.Keyfunc) OptionFunc {
}
}
func getClaim[T any](r *http.Request, claimAttr string, keyFunc jwt.Keyfunc) (T, error) {
rawToken := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
func FindToken(r *http.Request, getKeySet GetKeySetFunc) (jwt.Token, error) {
authorization := r.Header.Get("Authorization")
// Retrieve token from Authorization header
rawToken := strings.TrimPrefix(authorization, "Bearer ")
// Retrieve token from ?edge-auth=<value>
if rawToken == "" {
rawToken = r.URL.Query().Get("token")
rawToken = r.URL.Query().Get(CookieName)
}
if rawToken == "" {
return *new(T), errors.WithStack(ErrUnauthenticated)
cookie, err := r.Cookie(CookieName)
if err != nil && !errors.Is(err, http.ErrNoCookie) {
return nil, errors.WithStack(err)
}
if cookie != nil {
rawToken = cookie.Value
}
}
token, err := jwt.Parse(rawToken, keyFunc)
if rawToken == "" {
return nil, errors.WithStack(ErrUnauthenticated)
}
keySet, err := getKeySet()
if err != nil {
return nil, errors.WithStack(err)
}
token, err := jwt.Parse([]byte(rawToken),
jwt.WithKeySet(keySet, jws.WithRequireKid(false)),
jwt.WithValidate(true),
)
if err != nil {
return nil, errors.WithStack(err)
}
return token, nil
}
func getClaim[T any](r *http.Request, claimAttr string, getKeySet GetKeySetFunc) (T, error) {
token, err := FindToken(r, getKeySet)
if err != nil {
return *new(T), errors.WithStack(err)
}
if !token.Valid {
return *new(T), errors.Errorf("invalid jwt token: '%v'", token.Raw)
}
ctx := r.Context()
mapClaims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return *new(T), errors.Errorf("unexpected claims type '%T'", token.Claims)
mapClaims, err := token.AsMap(ctx)
if err != nil {
return *new(T), errors.WithStack(err)
}
rawClaim, exists := mapClaims[claimAttr]

View File

@ -2,7 +2,6 @@ package auth
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"testing"
@ -12,7 +11,9 @@ import (
"forge.cadoles.com/arcad/edge/pkg/app"
edgeHTTP "forge.cadoles.com/arcad/edge/pkg/http"
"forge.cadoles.com/arcad/edge/pkg/module"
"github.com/golang-jwt/jwt"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger"
)
@ -22,12 +23,12 @@ func TestAuthModule(t *testing.T) {
logger.SetLevel(slog.LevelDebug)
keyFunc, secret := getKeyFunc()
key := getDummyKey()
server := app.NewServer(
module.ConsoleModuleFactory(),
ModuleFactory(
WithJWT(keyFunc),
WithJWT(getDummyKeySet(key)),
),
)
@ -51,17 +52,22 @@ func TestAuthModule(t *testing.T) {
t.Fatalf("%+v", errors.WithStack(err))
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"sub": "jdoe",
"nbf": time.Now().UTC().Unix(),
})
token := jwt.New()
rawToken, err := token.SignedString(secret)
if err := token.Set(jwt.SubjectKey, "jdoe"); err != nil {
t.Fatalf("%+v", errors.WithStack(err))
}
if err := token.Set(jwt.NotBeforeKey, time.Now()); err != nil {
t.Fatalf("%+v", errors.WithStack(err))
}
rawToken, err := jwt.Sign(token, jwt.WithKey(jwa.HS256, key))
if err != nil {
t.Fatalf("%+v", errors.WithStack(err))
}
req.Header.Add("Authorization", "Bearer "+rawToken)
req.Header.Add("Authorization", "Bearer "+string(rawToken))
ctx := context.WithValue(context.Background(), edgeHTTP.ContextKeyOriginRequest, req)
@ -75,11 +81,11 @@ func TestAuthAnonymousModule(t *testing.T) {
logger.SetLevel(slog.LevelDebug)
keyFunc, _ := getKeyFunc()
key := getDummyKey()
server := app.NewServer(
module.ConsoleModuleFactory(),
ModuleFactory(WithJWT(keyFunc)),
ModuleFactory(WithJWT(getDummyKeySet(key))),
)
data, err := ioutil.ReadFile("testdata/auth_anonymous.js")
@ -109,16 +115,29 @@ func TestAuthAnonymousModule(t *testing.T) {
}
}
func getKeyFunc() (jwt.Keyfunc, []byte) {
func getDummyKey() jwk.Key {
secret := []byte("not_so_secret")
keyFunc := func(t *jwt.Token) (interface{}, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("Unexpected signing method: %v", t.Header["alg"])
}
return secret, nil
key, err := jwk.FromRaw(secret)
if err != nil {
panic(errors.WithStack(err))
}
return keyFunc, secret
if err := key.Set(jwk.AlgorithmKey, jwa.HS256); err != nil {
panic(errors.WithStack(err))
}
return key
}
func getDummyKeySet(key jwk.Key) GetKeySetFunc {
return func() (jwk.Set, error) {
set := jwk.NewSet()
if err := set.AddKey(key); err != nil {
return nil, errors.WithStack(err)
}
return set, nil
}
}