emissary/internal/auth/middleware.go

113 lines
2.5 KiB
Go
Raw Normal View History

package auth
import (
"context"
"net/http"
2024-02-26 18:20:40 +01:00
"forge.cadoles.com/Cadoles/emissary/internal/datastore"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/api"
"gitlab.com/wpetit/goweb/logger"
)
// API error codes defined by this package.
const (
	// ErrCodeUnauthorized is the "unauthorized" API error code.
	ErrCodeUnauthorized = api.ErrorCode("unauthorized")
	// ErrCodeForbidden is the "forbidden" API error code.
	ErrCodeForbidden = api.ErrorCode("forbidden")
)
// contextKey is the unexported type of the keys this package stores in a
// context.Context; using a private named type keeps them distinct from keys
// defined by other packages, even when the underlying strings are equal.
type contextKey string

// contextKeyUser is the context key under which the authenticated User is
// stored.
const contextKeyUser = contextKey("user")
2023-03-13 10:44:58 +01:00
// CtxUser returns the User stored in ctx under the package's user context
// key.
//
// It returns an error when ctx holds no value under that key, or when the
// stored value does not implement User.
func CtxUser(ctx context.Context) (User, error) {
	value := ctx.Value(contextKeyUser)

	user, isUser := value.(User)
	if !isUser {
		return nil, errors.Errorf("unexpected user type: expected '%T', got '%T'", new(User), value)
	}

	return user, nil
}
type User interface {
Subject() string
2024-02-26 18:20:40 +01:00
Tenant() datastore.TenantID
}
// Authenticator tries to identify the User behind an HTTP request.
//
// Middleware calls its authenticators in order:
//   - a non-nil error records a failed attempt, and ErrUnauthorized and
//     ErrUnauthenticated map to dedicated HTTP statuses;
//   - a (nil, nil) result means the authenticator did not apply, and the next
//     one is tried;
//   - a non-nil User stops the search.
//
// Implementations should return a nil User whenever they return a non-nil
// error.
type Authenticator interface {
	Authenticate(ctx context.Context, r *http.Request) (User, error)
}
func Middleware(authenticators ...Authenticator) func(http.Handler) http.Handler {
return func(h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
ctx := logger.With(r.Context(), logger.F("remoteAddr", r.RemoteAddr))
var (
user User
err error
)
2024-02-26 18:20:40 +01:00
var errs []error
for _, auth := range authenticators {
user, err = auth.Authenticate(ctx, r)
if err != nil {
2024-02-26 18:20:40 +01:00
errs = append(errs, errors.WithStack(err))
continue
}
if user != nil {
break
}
}
if user == nil {
hasUnauthorized, hasUnauthenticated, hasUnknown := checkErrors(errs)
2024-02-26 18:20:40 +01:00
switch {
case hasUnauthorized && !hasUnknown:
2024-02-26 18:20:40 +01:00
api.ErrorResponse(w, http.StatusForbidden, api.ErrCodeForbidden, nil)
return
case hasUnauthenticated && !hasUnknown:
api.ErrorResponse(w, http.StatusUnauthorized, api.ErrCodeUnauthorized, nil)
2024-02-26 18:20:40 +01:00
return
case hasUnknown:
2024-02-26 18:20:40 +01:00
api.ErrorResponse(w, http.StatusInternalServerError, api.ErrCodeUnknownError, nil)
return
default:
api.ErrorResponse(w, http.StatusUnauthorized, ErrCodeUnauthorized, nil)
return
}
}
ctx = logger.With(ctx, logger.F("user", user.Subject()))
ctx = context.WithValue(ctx, contextKeyUser, user)
h.ServeHTTP(w, r.WithContext(ctx))
}
return http.HandlerFunc(fn)
}
}
2024-02-26 18:20:40 +01:00
// checkErrors classifies the errors collected from failed authentication
// attempts.
//
// It returns three flags:
//   - isUnauthorized: at least one error matches ErrUnauthorized.
//   - isUnauthenticated: at least one error matches ErrUnauthenticated but
//     not ErrUnauthorized.
//   - isUnknown: at least one error matches neither sentinel.
//
// Matching uses errors.Is, so wrapped errors are recognized. An error that
// matches both sentinels counts only as unauthorized. All flags are false for
// an empty slice.
func checkErrors(errs []error) (isUnauthorized bool, isUnauthenticated bool, isUnknown bool) {
	for i := range errs {
		if errors.Is(errs[i], ErrUnauthorized) {
			isUnauthorized = true
		} else if errors.Is(errs[i], ErrUnauthenticated) {
			isUnauthenticated = true
		} else {
			isUnknown = true
		}
	}

	return isUnauthorized, isUnauthenticated, isUnknown
}