emissary/internal/server/api/authorization.go

252 lines
5.1 KiB
Go

package api
import (
"context"
"fmt"
"net/http"
"forge.cadoles.com/Cadoles/emissary/internal/auth"
"forge.cadoles.com/Cadoles/emissary/internal/auth/agent"
"forge.cadoles.com/Cadoles/emissary/internal/auth/user"
"forge.cadoles.com/Cadoles/emissary/internal/datastore"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/api"
"gitlab.com/wpetit/goweb/logger"
)
// ErrCodeForbidden is the API error code attached to every HTTP 403 response
// emitted by the authorization helpers of this file.
var ErrCodeForbidden = api.ErrorCode("forbidden")
func assertQueryAccess(h http.Handler) http.Handler {
return assertAuthz(
h,
assertOneOfRoles(user.RoleReader, user.RoleWriter, user.RoleAdmin),
nil,
)
}
func assertUserWithWriteAccess(h http.Handler) http.Handler {
return assertAuthz(
h,
assertOneOfRoles(user.RoleWriter, user.RoleAdmin),
nil,
)
}
func assertAgentOrUserWithWriteAccess(h http.Handler) http.Handler {
return assertAuthz(
h,
assertOneOfRoles(user.RoleWriter, user.RoleAdmin),
assertMatchingAgent(),
)
}
func assertAgentOrUserWithReadAccess(h http.Handler) http.Handler {
return assertAuthz(
h,
assertOneOfRoles(user.RoleReader, user.RoleWriter, user.RoleAdmin),
assertMatchingAgent(),
)
}
func assertAdminAccess(h http.Handler) http.Handler {
return assertAuthz(
h,
assertOneOfRoles(user.RoleAdmin),
nil,
)
}
func assertAdminOrTenantReadAccess(h http.Handler) http.Handler {
return assertAuthz(
h,
assertOneOfUser(
assertOneOfRoles(user.RoleAdmin),
assertAllOfUser(
assertOneOfRoles(user.RoleReader, user.RoleWriter),
assertTenant(),
),
),
nil,
)
}
func assertAdminOrTenantWriteAccess(h http.Handler) http.Handler {
return assertAuthz(
h,
assertOneOfUser(
assertOneOfRoles(user.RoleAdmin),
assertAllOfUser(
assertOneOfRoles(user.RoleWriter),
assertTenant(),
),
),
nil,
)
}
// assertAuthz wraps h so it is only invoked when the authenticated caller
// passes the check matching its identity kind:
//   - *user.User is checked with assertUser,
//   - *agent.User is checked with assertAgent.
//
// A nil check rejects that kind of caller outright. When no usable identity is
// found, assertRequestUser has already answered. Any other identity type is
// logged and rejected. Every rejection ends with a 403 via forbidden.
//
// NOTE(review): if a check already wrote a response itself (e.g. through
// getTenantID), forbidden writes a second one — confirm this cannot happen.
func assertAuthz(h http.Handler, assertUser assertUser, assertAgent assertAgent) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reqUser, ok := assertRequestUser(w, r)
		if !ok {
			// assertRequestUser has already written the error response.
			return
		}

		allowed := false

		switch u := reqUser.(type) {
		case *user.User:
			allowed = assertUser != nil && assertUser(w, r, u)
		case *agent.User:
			allowed = assertAgent != nil && assertAgent(w, r, u)
		default:
			logUnexpectedUserType(r.Context(), reqUser)
		}

		if !allowed {
			forbidden(w, r)
			return
		}

		h.ServeHTTP(w, r)
	})
}
type (
	// assertUser decides whether a human user may proceed with request r.
	// It returns true to allow the request. It also receives w, presumably so
	// request-parsing helpers it calls can report errors — verify per use.
	assertUser func(w http.ResponseWriter, r *http.Request, u *user.User) bool

	// assertAgent decides whether an agent identity may proceed with
	// request r. It returns true to allow the request.
	assertAgent func(w http.ResponseWriter, r *http.Request, u *agent.User) bool
)
// assertAllOfUser combines user assertions with a logical AND. Assertions run
// in the given order, and evaluation stops at the first one returning false.
// With no assertions at all, every user is accepted.
func assertAllOfUser(funcs ...assertUser) assertUser {
	return func(w http.ResponseWriter, r *http.Request, u *user.User) bool {
		for i := range funcs {
			if !funcs[i](w, r, u) {
				return false
			}
		}

		return true
	}
}
// assertOneOfUser combines user assertions with a logical OR. Assertions run
// in the given order, and evaluation stops at the first one returning true.
// With no assertions at all, every user is rejected.
func assertOneOfUser(funcs ...assertUser) assertUser {
	return func(w http.ResponseWriter, r *http.Request, u *user.User) bool {
		for i := range funcs {
			if funcs[i](w, r, u) {
				return true
			}
		}

		return false
	}
}
// assertTenant accepts a user only when a tenant ID can be extracted from the
// request (via getTenantID) and it equals the user's own tenant.
func assertTenant() assertUser {
	return func(w http.ResponseWriter, r *http.Request, u *user.User) bool {
		tenantID, found := getTenantID(w, r)

		// Short-circuit keeps u.Tenant() from being called when no ID was found.
		return found && u.Tenant() == tenantID
	}
}
// assertOneOfRoles accepts a user whose role is any of the listed roles.
// With no roles listed, every user is rejected.
func assertOneOfRoles(roles ...user.Role) assertUser {
	return func(w http.ResponseWriter, r *http.Request, u *user.User) bool {
		current := u.Role()

		for i := 0; i < len(roles); i++ {
			if roles[i] == current {
				return true
			}
		}

		return false
	}
}
// assertMatchingAgent accepts an agent identity only when an agent ID can be
// extracted from the request (via getAgentID) and it equals the ID of the
// authenticated agent. An identity without an attached agent is rejected.
func assertMatchingAgent() assertAgent {
	return func(w http.ResponseWriter, r *http.Request, u *agent.User) bool {
		agentID, found := getAgentID(w, r)
		if !found {
			return false
		}

		// Named "authenticated" rather than "agent" to avoid shadowing the package.
		authenticated := u.Agent()

		return authenticated != nil && authenticated.ID == agentID
	}
}
// assertRequestUser fetches the authenticated identity stored in the request
// context by auth.CtxUser.
//
// It returns (user, true) on success. It responds with a 403 and returns
// (nil, false) in these cases:
//   - the lookup fails (the error is also logged),
//   - no identity is present,
//   - the identity has an empty tenant.
func assertRequestUser(w http.ResponseWriter, r *http.Request) (auth.User, bool) {
	ctx := r.Context()

	// Named "reqUser" rather than "user" to avoid shadowing the user package.
	reqUser, err := auth.CtxUser(ctx)

	switch {
	case err != nil:
		logger.Error(ctx, "could not retrieve user", logger.CapturedE(errors.WithStack(err)))
		forbidden(w, r)

		return nil, false
	case reqUser == nil || reqUser.Tenant() == "":
		forbidden(w, r)

		return nil, false
	}

	return reqUser, true
}
// assertTenantOwns reports whether the agent identified by agentID belongs to
// the tenant of the authenticated caller.
//
// When it returns false, an error response has already been written:
//   - by assertRequestUser, when there is no usable identity,
//   - a 500 with ErrCodeUnknownError, when the agent cannot be loaded (the
//     error is logged),
//   - a 403 with ErrCodeForbidden, when the agent has no tenant or a
//     different one.
func (m *Mount) assertTenantOwns(w http.ResponseWriter, r *http.Request, agentID datastore.AgentID) bool {
	ctx := r.Context()

	reqUser, ok := assertRequestUser(w, r)
	if !ok {
		return false
	}

	ag, err := m.agentRepo.Get(ctx, agentID)
	if err != nil {
		err = errors.WithStack(err)
		logger.Error(ctx, "could not get agent", logger.CapturedE(err))
		api.ErrorResponse(w, http.StatusInternalServerError, ErrCodeUnknownError, nil)

		// Stop here. The agent value is unusable, and a response has already
		// been written. Falling through would dereference it and/or write a
		// second response.
		return false
	}

	if ag.TenantID != nil && *ag.TenantID == reqUser.Tenant() {
		return true
	}

	// Same 403 payload as before, now logged like every other rejection here.
	forbidden(w, r)

	return false
}
// forbidden logs a warning carrying the request path, then answers with
// HTTP 403 and the ErrCodeForbidden error code.
func forbidden(w http.ResponseWriter, r *http.Request) {
	path := r.URL.Path
	logger.Warn(r.Context(), "forbidden", logger.F("path", path))

	api.ErrorResponse(w, http.StatusForbidden, ErrCodeForbidden, nil)
}
// logUnexpectedUserType emits a warning for an authenticated identity whose
// concrete type is not handled by assertAuthz. It records the subject and the
// dynamic Go type of the identity.
func logUnexpectedUserType(ctx context.Context, user auth.User) {
	subject := user.Subject()
	typeName := fmt.Sprintf("%T", user)

	logger.Warn(ctx, "unexpected user type",
		logger.F("subject", subject),
		logger.F("type", typeName),
	)
}