package api

import (
	"context"
	"fmt"
	"net/http"

	"forge.cadoles.com/Cadoles/emissary/internal/auth"
	"forge.cadoles.com/Cadoles/emissary/internal/auth/agent"
	"forge.cadoles.com/Cadoles/emissary/internal/auth/user"
	"forge.cadoles.com/Cadoles/emissary/internal/datastore"
	"github.com/pkg/errors"
	"gitlab.com/wpetit/goweb/api"
	"gitlab.com/wpetit/goweb/logger"
)

// ErrCodeForbidden is the API error code sent with HTTP 403 responses when the
// authenticated principal is not allowed to perform the requested operation.
var ErrCodeForbidden api.ErrorCode = "forbidden"

// assertQueryAccess wraps h so that only users holding the reader, writer or
// admin role reach it. Agent principals are always rejected.
func assertQueryAccess(h http.Handler) http.Handler {
	return assertAuthz(
		h,
		assertOneOfRoles(user.RoleReader, user.RoleWriter, user.RoleAdmin),
		nil,
	)
}

// assertUserWithWriteAccess wraps h so that only users holding the writer or
// admin role reach it. Agent principals are always rejected.
func assertUserWithWriteAccess(h http.Handler) http.Handler {
	return assertAuthz(
		h,
		assertOneOfRoles(user.RoleWriter, user.RoleAdmin),
		nil,
	)
}

// assertAgentOrUserWithWriteAccess wraps h so that it is reached either by a
// user holding the writer or admin role, or by an agent whose own ID matches
// the agent ID targeted by the request (see assertMatchingAgent).
func assertAgentOrUserWithWriteAccess(h http.Handler) http.Handler {
	return assertAuthz(
		h,
		assertOneOfRoles(user.RoleWriter, user.RoleAdmin),
		assertMatchingAgent(),
	)
}

// assertAgentOrUserWithReadAccess wraps h so that it is reached either by a
// user holding the reader, writer or admin role, or by an agent whose own ID
// matches the agent ID targeted by the request (see assertMatchingAgent).
func assertAgentOrUserWithReadAccess(h http.Handler) http.Handler {
	return assertAuthz(
		h,
		assertOneOfRoles(user.RoleReader, user.RoleWriter, user.RoleAdmin),
		assertMatchingAgent(),
	)
}

// assertAdminAccess wraps h so that only users holding the admin role reach
// it. Agent principals are always rejected.
func assertAdminAccess(h http.Handler) http.Handler {
	return assertAuthz(
		h,
		assertOneOfRoles(user.RoleAdmin),
		nil,
	)
}

// assertAdminOrTenantReadAccess wraps h so that it is reached by admins, or by
// readers/writers whose tenant matches the tenant targeted by the request
// (see assertTenant). Agent principals are always rejected.
func assertAdminOrTenantReadAccess(h http.Handler) http.Handler {
	return assertAuthz(
		h,
		assertOneOfUser(
			assertOneOfRoles(user.RoleAdmin),
			assertAllOfUser(
				assertOneOfRoles(user.RoleReader, user.RoleWriter),
				assertTenant(),
			),
		),
		nil,
	)
}

// assertAdminOrTenantWriteAccess wraps h so that it is reached by admins, or
// by writers whose tenant matches the tenant targeted by the request (see
// assertTenant). Agent principals are always rejected.
func assertAdminOrTenantWriteAccess(h http.Handler) http.Handler {
	return assertAuthz(
		h,
		assertOneOfUser(
			assertOneOfRoles(user.RoleAdmin),
			assertAllOfUser(
				assertOneOfRoles(user.RoleWriter),
				assertTenant(),
			),
		),
		nil,
	)
}

// assertAuthz is the core authorization middleware. It resolves the request's
// principal via assertRequestUser, then dispatches on its concrete type:
//   - *user.User is checked with checkUser,
//   - *agent.User is checked with checkAgent.
//
// A nil check function means principals of that kind are never allowed. If the
// check passes, h is invoked. Otherwise, or for any other principal type, a 403
// is written via forbidden.
func assertAuthz(h http.Handler, checkUser assertUser, checkAgent assertAgent) http.Handler {
	fn := func(w http.ResponseWriter, r *http.Request) {
		reqUser, ok := assertRequestUser(w, r)
		if !ok {
			// assertRequestUser has already written the error response.
			return
		}

		switch u := reqUser.(type) {
		case *user.User:
			if checkUser != nil && checkUser(w, r, u) {
				h.ServeHTTP(w, r)
				return
			}
		case *agent.User:
			if checkAgent != nil && checkAgent(w, r, u) {
				h.ServeHTTP(w, r)
				return
			}
		default:
			logUnexpectedUserType(r.Context(), reqUser)
		}

		forbidden(w, r)
	}

	return http.HandlerFunc(fn)
}

// assertUser reports whether the given human user is allowed to proceed.
type assertUser func(w http.ResponseWriter, r *http.Request, u *user.User) bool

// assertAgent reports whether the given agent principal is allowed to proceed.
type assertAgent func(w http.ResponseWriter, r *http.Request, u *agent.User) bool

// assertAllOfUser combines checks with logical AND. Evaluation stops at the
// first failing check. With no checks it always succeeds.
func assertAllOfUser(funcs ...assertUser) assertUser {
	return func(w http.ResponseWriter, r *http.Request, u *user.User) bool {
		for _, fn := range funcs {
			if !fn(w, r, u) {
				return false
			}
		}

		return true
	}
}

// assertOneOfUser combines checks with logical OR. Evaluation stops at the
// first passing check. With no checks it always fails.
func assertOneOfUser(funcs ...assertUser) assertUser {
	return func(w http.ResponseWriter, r *http.Request, u *user.User) bool {
		for _, fn := range funcs {
			if fn(w, r, u) {
				return true
			}
		}

		return false
	}
}

// assertTenant succeeds when the user's tenant equals the tenant ID obtained
// from the request by getTenantID.
//
// NOTE(review): getTenantID is defined elsewhere. If it writes its own error
// response when it fails, assertAuthz will also write a 403 afterwards. Confirm
// its contract.
func assertTenant() assertUser {
	return func(w http.ResponseWriter, r *http.Request, u *user.User) bool {
		tenantID, ok := getTenantID(w, r)
		if !ok {
			return false
		}

		return u.Tenant() == tenantID
	}
}

// assertOneOfRoles succeeds when the user's role is any of the given roles.
func assertOneOfRoles(roles ...user.Role) assertUser {
	return func(w http.ResponseWriter, r *http.Request, u *user.User) bool {
		role := u.Role()
		for _, rr := range roles {
			if rr == role {
				return true
			}
		}

		return false
	}
}

// assertMatchingAgent succeeds when the authenticated agent's ID equals the
// agent ID obtained from the request by getAgentID. In other words, an agent
// may only act on itself.
func assertMatchingAgent() assertAgent {
	return func(w http.ResponseWriter, r *http.Request, u *agent.User) bool {
		agentID, ok := getAgentID(w, r)
		if !ok {
			return false
		}

		a := u.Agent()

		return a != nil && a.ID == agentID
	}
}

// assertRequestUser retrieves the authenticated principal from the request
// context. It writes a 403 response and returns false in three cases: the
// principal cannot be retrieved, it is nil, or it has an empty tenant.
func assertRequestUser(w http.ResponseWriter, r *http.Request) (auth.User, bool) {
	ctx := r.Context()

	reqUser, err := auth.CtxUser(ctx)
	if err != nil {
		err = errors.WithStack(err)
		logger.Error(ctx, "could not retrieve user", logger.CapturedE(err))
		forbidden(w, r)

		return nil, false
	}

	if reqUser == nil || reqUser.Tenant() == "" {
		forbidden(w, r)

		return nil, false
	}

	return reqUser, true
}

// assertTenantOwns reports whether the agent identified by agentID belongs to
// the tenant of the request's authenticated principal. On any failure it writes
// the error response itself and returns false, so callers must stop handling
// the request:
//   - 403 if the principal is missing, the agent is missing, or the agent
//     belongs to another tenant or to no tenant,
//   - 500 if the agent lookup fails.
func (m *Mount) assertTenantOwns(w http.ResponseWriter, r *http.Request, agentID datastore.AgentID) bool {
	ctx := r.Context()

	reqUser, ok := assertRequestUser(w, r)
	if !ok {
		return false
	}

	ag, err := m.agentRepo.Get(ctx, agentID)
	if err != nil {
		err = errors.WithStack(err)
		logger.Error(ctx, "could not get agent", logger.CapturedE(err))
		api.ErrorResponse(w, http.StatusInternalServerError, ErrCodeUnknownError, nil)

		// Stop here: the response has been written and ag must not be used.
		return false
	}

	if ag != nil && ag.TenantID != nil && *ag.TenantID == reqUser.Tenant() {
		return true
	}

	forbidden(w, r)

	return false
}

// forbidden logs the rejected path and writes a 403 response with
// ErrCodeForbidden.
func forbidden(w http.ResponseWriter, r *http.Request) {
	logger.Warn(r.Context(), "forbidden", logger.F("path", r.URL.Path))
	api.ErrorResponse(w, http.StatusForbidden, ErrCodeForbidden, nil)
}

// logUnexpectedUserType logs a warning for a principal whose concrete type is
// neither *user.User nor *agent.User.
func logUnexpectedUserType(ctx context.Context, user auth.User) {
	logger.Warn(
		ctx, "unexpected user type",
		logger.F("subject", user.Subject()),
		logger.F("type", fmt.Sprintf("%T", user)),
	)
}