package auth

import (
	"context"
	"net/http"

	"forge.cadoles.com/Cadoles/emissary/internal/datastore"
	"github.com/pkg/errors"
	"gitlab.com/wpetit/goweb/api"
	"gitlab.com/wpetit/goweb/logger"
)

const (
	// ErrCodeUnauthorized is the API error code "unauthorized"; Middleware
	// sends it with a 401 response when no authenticator reported any error
	// and none produced a user.
	ErrCodeUnauthorized api.ErrorCode = "unauthorized"
	// ErrCodeForbidden is the API error code "forbidden".
	ErrCodeForbidden api.ErrorCode = "forbidden"
)

type contextKey string

const (
	contextKeyUser contextKey = "user"
)

// CtxUser returns the User that Middleware stored in ctx.
//
// It returns an error when ctx holds no user, or holds a value under the user
// key that does not implement User.
func CtxUser(ctx context.Context) (User, error) {
	user, ok := ctx.Value(contextKeyUser).(User)
	if !ok {
		return nil, errors.Errorf("unexpected user type: expected '%T', got '%T'", new(User), ctx.Value(contextKeyUser))
	}

	return user, nil
}

// User is the identity attached to an authenticated request.
type User interface {
	// Subject identifies the user; Middleware adds it to the request's logger
	// context under the "user" field.
	Subject() string
	// Tenant returns the datastore tenant associated with the user.
	Tenant() datastore.TenantID
}

// Authenticator attempts to identify the user behind an HTTP request.
//
// As consumed by Middleware: a non-nil User with a nil error means success;
// (nil, nil) means "not applicable, try the next authenticator"; a non-nil
// error means failure, and any User returned alongside it is ignored. Errors
// matching ErrUnauthorized or ErrUnauthenticated (via errors.Is) produce 403
// and 401 responses respectively; any other error produces a 500.
type Authenticator interface {
	Authenticate(context.Context, *http.Request) (User, error)
}

// Middleware returns an HTTP middleware that authenticates every request by
// trying the given authenticators in order.
//
// The first authenticator returning a non-nil User with a nil error wins: the
// user is stored in the request context (retrieve it with CtxUser), its
// subject is added to the logger context, and the wrapped handler is invoked.
//
// If no authenticator succeeds, the request is rejected based on the errors
// collected along the way:
//   - any error matching neither ErrUnauthorized nor ErrUnauthenticated:
//     the errors are logged and a 500 is returned;
//   - otherwise, any ErrUnauthorized error: 403;
//   - otherwise, any ErrUnauthenticated error: 401;
//   - no error at all (no authenticators, or all returned (nil, nil)): 401.
func Middleware(authenticators ...Authenticator) func(http.Handler) http.Handler {
	return func(h http.Handler) http.Handler {
		fn := func(w http.ResponseWriter, r *http.Request) {
			ctx := logger.With(r.Context(), logger.F("remoteAddr", r.RemoteAddr))

			var (
				user User
				errs []error
			)

			for _, auth := range authenticators {
				// Use loop-local results: a User returned together with an
				// error must never authenticate the request, so the outer
				// user is only set on a clean success.
				u, err := auth.Authenticate(ctx, r)
				if err != nil {
					errs = append(errs, errors.WithStack(err))
					continue
				}

				if u != nil {
					user = u
					break
				}
			}

			if user == nil {
				hasUnauthorized, hasUnauthenticated, hasUnknown := checkErrors(errs)

				// NOTE(review): the branches below mix api.ErrCode* and this
				// package's ErrCode* constants; confirm their values match
				// before unifying them.
				switch {
				case hasUnknown:
					// Unexpected failures would otherwise be invisible behind
					// a bare 500, so record them before responding.
					for _, err := range errs {
						logger.Error(ctx, "could not authenticate request", logger.F("error", err))
					}

					api.ErrorResponse(w, http.StatusInternalServerError, api.ErrCodeUnknownError, nil)
				case hasUnauthorized:
					api.ErrorResponse(w, http.StatusForbidden, api.ErrCodeForbidden, nil)
				case hasUnauthenticated:
					api.ErrorResponse(w, http.StatusUnauthorized, api.ErrCodeUnauthorized, nil)
				default:
					api.ErrorResponse(w, http.StatusUnauthorized, ErrCodeUnauthorized, nil)
				}

				return
			}

			ctx = logger.With(ctx, logger.F("user", user.Subject()))
			ctx = context.WithValue(ctx, contextKeyUser, user)

			h.ServeHTTP(w, r.WithContext(ctx))
		}

		return http.HandlerFunc(fn)
	}
}

// checkErrors classifies authentication errors. It reports whether errs holds
// at least one error matching ErrUnauthorized, at least one matching
// ErrUnauthenticated (checked only when ErrUnauthorized does not match), and
// at least one matching neither.
func checkErrors(errs []error) (isUnauthorized bool, isUnauthenticated bool, isUnknown bool) {
	for _, err := range errs {
		switch {
		case errors.Is(err, ErrUnauthorized):
			isUnauthorized = true
		case errors.Is(err, ErrUnauthenticated):
			isUnauthenticated = true
		default:
			isUnknown = true
		}
	}

	return isUnauthorized, isUnauthenticated, isUnknown
}