package server

import (
	"context"
	"fmt"
	"net/http"

	"forge.cadoles.com/Cadoles/emissary/internal/auth"
	"forge.cadoles.com/Cadoles/emissary/internal/auth/agent"
	"forge.cadoles.com/Cadoles/emissary/internal/auth/thirdparty"
	"github.com/pkg/errors"
	"gitlab.com/wpetit/goweb/api"
	"gitlab.com/wpetit/goweb/logger"
)

// ErrCodeForbidden is the API error code sent with the HTTP 403 responses
// written by forbidden.
var ErrCodeForbidden api.ErrorCode = "forbidden"

// assertGlobalReadAccess wraps h so that it is only served to third-party
// users holding the reader or writer role. Agent users never get global read
// access, and any other user type is logged as unexpected. Every rejected
// request receives a 403 response.
func assertGlobalReadAccess(h http.Handler) http.Handler {
	fn := func(w http.ResponseWriter, r *http.Request) {
		reqUser, ok := assertRequestUser(w, r)
		if !ok {
			return
		}

		switch user := reqUser.(type) {
		case *thirdparty.User:
			if thirdPartyCanRead(user) {
				h.ServeHTTP(w, r)
				return
			}
		case *agent.User:
			// Agents don't have global read access.
		default:
			logUnexpectedUserType(r.Context(), reqUser)
		}

		forbidden(w, r)
	}

	return http.HandlerFunc(fn)
}

// assertAgentWriteAccess wraps h so that it is only served to third-party
// users holding the writer role, or to the agent user whose agent ID matches
// the one resolved from the request by getAgentID. Other requests receive a
// 403 response.
func assertAgentWriteAccess(h http.Handler) http.Handler {
	return assertAgentAccess(h, thirdPartyCanWrite)
}

// assertAgentReadAccess wraps h so that it is only served to third-party
// users holding the reader or writer role, or to the agent user whose agent
// ID matches the one resolved from the request by getAgentID. Other requests
// receive a 403 response.
func assertAgentReadAccess(h http.Handler) http.Handler {
	return assertAgentAccess(h, thirdPartyCanRead)
}

// assertAgentAccess is the shared implementation of the per-agent access
// middlewares. thirdPartyAllowed decides whether a third-party user may
// proceed; an agent user may only proceed when its own agent ID equals the
// agent ID targeted by the request.
func assertAgentAccess(h http.Handler, thirdPartyAllowed func(user *thirdparty.User) bool) http.Handler {
	fn := func(w http.ResponseWriter, r *http.Request) {
		reqUser, ok := assertRequestUser(w, r)
		if !ok {
			return
		}

		// The agent ID is resolved before looking at the user type, so when
		// getAgentID fails the request is never forwarded, whoever sent it.
		// NOTE(review): assumes getAgentID writes the error response itself
		// when it returns false — confirm.
		agentID, ok := getAgentID(w, r)
		if !ok {
			return
		}

		switch user := reqUser.(type) {
		case *thirdparty.User:
			if thirdPartyAllowed(user) {
				h.ServeHTTP(w, r)
				return
			}
		case *agent.User:
			// An agent may only act on its own agent ID.
			if user.Agent().ID == agentID {
				h.ServeHTTP(w, r)
				return
			}
		default:
			logUnexpectedUserType(r.Context(), reqUser)
		}

		forbidden(w, r)
	}

	return http.HandlerFunc(fn)
}

// thirdPartyCanRead reports whether user holds the reader or writer role.
func thirdPartyCanRead(user *thirdparty.User) bool {
	role := user.Role()

	return role == thirdparty.RoleReader || role == thirdparty.RoleWriter
}

// thirdPartyCanWrite reports whether user holds the writer role.
func thirdPartyCanWrite(user *thirdparty.User) bool {
	return user.Role() == thirdparty.RoleWriter
}

// assertRequestUser returns the user attached to the request context by
// auth.CtxUser. If retrieving the user fails (the error is logged) or no user
// is attached, it writes a 403 response and returns false; callers must then
// stop handling the request.
func assertRequestUser(w http.ResponseWriter, r *http.Request) (auth.User, bool) {
	ctx := r.Context()

	user, err := auth.CtxUser(ctx)
	if err != nil {
		logger.Error(ctx, "could not retrieve user", logger.E(errors.WithStack(err)))
		forbidden(w, r)

		return nil, false
	}

	if user == nil {
		forbidden(w, r)

		return nil, false
	}

	return user, true
}

// forbidden logs a warning with the request path and writes an HTTP 403 API
// error response carrying ErrCodeForbidden.
func forbidden(w http.ResponseWriter, r *http.Request) {
	logger.Warn(r.Context(), "forbidden", logger.F("path", r.URL.Path))
	api.ErrorResponse(w, http.StatusForbidden, ErrCodeForbidden, nil)
}

// logUnexpectedUserType logs, at error level, the subject and dynamic Go type
// of a user that none of the access checks above know how to handle.
func logUnexpectedUserType(ctx context.Context, user auth.User) {
	logger.Error(
		ctx, "unexpected user type",
		logger.F("subject", user.Subject()),
		logger.F("type", fmt.Sprintf("%T", user)),
	)
}