113 lines
2.5 KiB
Go
113 lines
2.5 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
|
|
"forge.cadoles.com/Cadoles/emissary/internal/datastore"
|
|
"github.com/pkg/errors"
|
|
"gitlab.com/wpetit/goweb/api"
|
|
"gitlab.com/wpetit/goweb/logger"
|
|
)
|
|
|
|
const (
|
|
ErrCodeUnauthorized api.ErrorCode = "unauthorized"
|
|
ErrCodeForbidden api.ErrorCode = "forbidden"
|
|
)
|
|
|
|
// contextKey is the unexported type of the keys this package uses to store
// values in a context.Context.
type contextKey string
|
|
|
|
const (
|
|
contextKeyUser contextKey = "user"
|
|
)
|
|
|
|
func CtxUser(ctx context.Context) (User, error) {
|
|
user, ok := ctx.Value(contextKeyUser).(User)
|
|
if !ok {
|
|
return nil, errors.Errorf("unexpected user type: expected '%T', got '%T'", new(User), ctx.Value(contextKeyUser))
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
type User interface {
|
|
Subject() string
|
|
Tenant() datastore.TenantID
|
|
}
|
|
|
|
type Authenticator interface {
|
|
Authenticate(context.Context, *http.Request) (User, error)
|
|
}
|
|
|
|
func Middleware(authenticators ...Authenticator) func(http.Handler) http.Handler {
|
|
return func(h http.Handler) http.Handler {
|
|
fn := func(w http.ResponseWriter, r *http.Request) {
|
|
ctx := logger.With(r.Context(), logger.F("remoteAddr", r.RemoteAddr))
|
|
|
|
var (
|
|
user User
|
|
err error
|
|
)
|
|
|
|
var errs []error
|
|
|
|
for _, auth := range authenticators {
|
|
user, err = auth.Authenticate(ctx, r)
|
|
if err != nil {
|
|
errs = append(errs, errors.WithStack(err))
|
|
continue
|
|
}
|
|
|
|
if user != nil {
|
|
break
|
|
}
|
|
}
|
|
|
|
if user == nil {
|
|
hasUnauthorized, hasUnauthenticated, hasUnknown := checkErrors(errs)
|
|
|
|
switch {
|
|
case hasUnauthorized && !hasUnknown:
|
|
api.ErrorResponse(w, http.StatusForbidden, api.ErrCodeForbidden, nil)
|
|
return
|
|
case hasUnauthenticated && !hasUnknown:
|
|
api.ErrorResponse(w, http.StatusUnauthorized, api.ErrCodeUnauthorized, nil)
|
|
return
|
|
case hasUnknown:
|
|
api.ErrorResponse(w, http.StatusInternalServerError, api.ErrCodeUnknownError, nil)
|
|
return
|
|
default:
|
|
api.ErrorResponse(w, http.StatusUnauthorized, ErrCodeUnauthorized, nil)
|
|
return
|
|
}
|
|
}
|
|
|
|
ctx = logger.With(ctx, logger.F("user", user.Subject()))
|
|
ctx = context.WithValue(ctx, contextKeyUser, user)
|
|
|
|
h.ServeHTTP(w, r.WithContext(ctx))
|
|
}
|
|
|
|
return http.HandlerFunc(fn)
|
|
}
|
|
}
|
|
|
|
func checkErrors(errs []error) (isUnauthorized bool, isUnauthenticated bool, isUnknown bool) {
|
|
isUnauthenticated = false
|
|
isUnauthorized = false
|
|
isUnknown = false
|
|
|
|
for _, e := range errs {
|
|
switch {
|
|
case errors.Is(e, ErrUnauthorized):
|
|
isUnauthorized = true
|
|
case errors.Is(e, ErrUnauthenticated):
|
|
isUnauthenticated = true
|
|
default:
|
|
isUnknown = true
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|