// Package auth provides an HTTP middleware that authenticates requests
// through a chain of Authenticators, and a helper to read the resulting
// user back from the request context.
package auth

import (
	"context"
	"net/http"

	"github.com/pkg/errors"
	"gitlab.com/wpetit/goweb/api"
	"gitlab.com/wpetit/goweb/logger"
)

// API error codes used in error responses.
const (
	// ErrCodeUnauthorized is sent by Middleware when no authenticator
	// identifies the request.
	ErrCodeUnauthorized api.ErrorCode = "unauthorized"
	// ErrCodeForbidden identifies forbidden requests; Middleware does not
	// emit it itself.
	ErrCodeForbidden api.ErrorCode = "forbidden"
)

// contextKey is unexported so that keys set by this package cannot collide
// with context keys defined in other packages.
type contextKey string

const (
	// contextKeyUser is the context key under which Middleware stores the
	// authenticated User.
	contextKeyUser contextKey = "user"
)

// CtxUser returns the authenticated user that Middleware stored in ctx.
//
// It returns an error when ctx carries no user, or carries a value that does
// not implement User. The pointer-to-interface return type is kept for
// backward compatibility with existing callers.
func CtxUser(ctx context.Context) (*User, error) {
	raw := ctx.Value(contextKeyUser)

	// Middleware stores the User interface value itself, so the dynamic type
	// held by the context is the concrete implementation (never *User):
	// the assertion must target the interface type.
	user, ok := raw.(User)
	if !ok {
		return nil, errors.Errorf("unexpected user type: expected an implementation of 'User', got '%T'", raw)
	}

	return &user, nil
}

// Sentinel errors for authentication and authorization failures. Middleware
// does not inspect them itself; compare against them with errors.Is.
var (
	ErrUnauthenticated = errors.New("unauthenticated")
	ErrForbidden       = errors.New("forbidden")
)

// User is an authenticated principal.
type User interface {
	// Subject identifies the user; Middleware adds it to the request logger
	// under the "user" field.
	Subject() string
}

// Authenticator identifies the User behind an HTTP request.
//
// When used with Middleware, returning a non-nil error, or a nil User with a
// nil error, makes Middleware fall through to the next Authenticator.
type Authenticator interface {
	Authenticate(context.Context, *http.Request) (User, error)
}

// Middleware returns an HTTP middleware that authenticates each request by
// trying authenticators in order. The first authenticator that returns a
// non-nil User without error wins. Errors from individual authenticators are
// logged as warnings and the next authenticator is tried.
//
// If no authenticator succeeds, the middleware responds with
// http.StatusUnauthorized and ErrCodeUnauthorized, and does not call the
// wrapped handler. Otherwise the user is stored in the request context
// (retrievable with CtxUser) and the wrapped handler is invoked.
func Middleware(authenticators ...Authenticator) func(http.Handler) http.Handler {
	return func(h http.Handler) http.Handler {
		fn := func(w http.ResponseWriter, r *http.Request) {
			ctx := logger.With(r.Context(), logger.F("remoteAddr", r.RemoteAddr))

			var user User

			for _, auth := range authenticators {
				// Iteration-local results: when an authenticator fails, the
				// user it returned is meaningless and must never leak into
				// the outer variable (it would otherwise grant access if the
				// failing authenticator is the last one).
				candidate, err := auth.Authenticate(ctx, r)
				if err != nil {
					logger.Warn(ctx, "could not authenticate request", logger.E(errors.WithStack(err)))
					continue
				}

				if candidate != nil {
					user = candidate
					break
				}
			}

			if user == nil {
				api.ErrorResponse(w, http.StatusUnauthorized, ErrCodeUnauthorized, nil)
				return
			}

			ctx = logger.With(ctx, logger.F("user", user.Subject()))
			ctx = context.WithValue(ctx, contextKeyUser, user)

			h.ServeHTTP(w, r.WithContext(ctx))
		}

		return http.HandlerFunc(fn)
	}
}