2023-03-13 10:44:58 +01:00
|
|
|
package server
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"fmt"
|
|
|
|
"net/http"
|
|
|
|
|
|
|
|
"forge.cadoles.com/Cadoles/emissary/internal/auth"
|
|
|
|
"forge.cadoles.com/Cadoles/emissary/internal/auth/agent"
|
|
|
|
"forge.cadoles.com/Cadoles/emissary/internal/auth/thirdparty"
|
2023-10-13 12:30:52 +02:00
|
|
|
"github.com/getsentry/sentry-go"
|
2023-03-13 10:44:58 +01:00
|
|
|
"github.com/pkg/errors"
|
|
|
|
"gitlab.com/wpetit/goweb/api"
|
|
|
|
"gitlab.com/wpetit/goweb/logger"
|
|
|
|
)
|
|
|
|
|
|
|
|
// ErrCodeForbidden is the API error code returned by forbidden when a request
// fails one of the authorization checks.
var ErrCodeForbidden = api.ErrorCode("forbidden")
|
|
|
|
|
|
|
|
// assertGlobalReadAccess wraps h so that it only runs for users allowed to
// read across all agents.
//
// Third-party users holding either the reader or the writer role are let
// through. Agent users are always rejected. Any other user type is logged as
// unexpected and rejected. Every rejected request receives a 403 response
// written by forbidden.
func assertGlobalReadAccess(h http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reqUser, ok := assertRequestUser(w, r)
		if !ok {
			// assertRequestUser has already written the 403 response.
			return
		}

		switch u := reqUser.(type) {
		case *thirdparty.User:
			switch u.Role() {
			case thirdparty.RoleReader, thirdparty.RoleWriter:
				h.ServeHTTP(w, r)
				return
			}
		case *agent.User:
			// Agent users are never granted global read access.
		default:
			logUnexpectedUserType(r.Context(), reqUser)
		}

		forbidden(w, r)
	})
}
|
|
|
|
|
|
|
|
// assertAgentWriteAccess wraps h so that it only runs for users allowed to
// modify the agent targeted by the request.
//
// The target agent ID is resolved with getAgentID before any permission check.
// Write access is granted to:
//   - third-party users holding the writer role;
//   - agent users whose own agent ID equals the requested one.
//
// Any other user type is logged as unexpected. Every request that is not
// granted access receives a 403 response written by forbidden.
func assertAgentWriteAccess(h http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reqUser, ok := assertRequestUser(w, r)
		if !ok {
			// assertRequestUser has already written the 403 response.
			return
		}

		agentID, ok := getAgentID(w, r)
		if !ok {
			// NOTE(review): presumably getAgentID writes its own error response
			// when it fails — confirm.
			return
		}

		allowed := false

		switch u := reqUser.(type) {
		case *thirdparty.User:
			allowed = u.Role() == thirdparty.RoleWriter
		case *agent.User:
			// An agent may only write to its own resources.
			allowed = u.Agent().ID == agentID
		default:
			logUnexpectedUserType(r.Context(), reqUser)
		}

		if !allowed {
			forbidden(w, r)
			return
		}

		h.ServeHTTP(w, r)
	})
}
|
|
|
|
|
|
|
|
// assertAgentReadAccess wraps h so that it only runs for users allowed to read
// the agent targeted by the request.
//
// The target agent ID is resolved with getAgentID before any permission check.
// Read access is granted to:
//   - third-party users holding the reader or the writer role;
//   - agent users whose own agent ID equals the requested one.
//
// Any other user type is logged as unexpected. Every request that is not
// granted access receives a 403 response written by forbidden.
func assertAgentReadAccess(h http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reqUser, ok := assertRequestUser(w, r)
		if !ok {
			// assertRequestUser has already written the 403 response.
			return
		}

		agentID, ok := getAgentID(w, r)
		if !ok {
			// NOTE(review): presumably getAgentID writes its own error response
			// when it fails — confirm.
			return
		}

		allowed := false

		switch u := reqUser.(type) {
		case *thirdparty.User:
			role := u.Role()
			allowed = role == thirdparty.RoleReader || role == thirdparty.RoleWriter
		case *agent.User:
			// An agent may only read its own resources.
			allowed = u.Agent().ID == agentID
		default:
			logUnexpectedUserType(r.Context(), reqUser)
		}

		if !allowed {
			forbidden(w, r)
			return
		}

		h.ServeHTTP(w, r)
	})
}
|
|
|
|
|
|
|
|
// assertRequestUser fetches the authenticated user stored in the request
// context via auth.CtxUser.
//
// On success it returns the user and true. Otherwise it writes a 403 response
// with forbidden and returns (nil, false), and the caller must stop handling
// the request. This happens when:
//   - auth.CtxUser returns an error, which is also wrapped with a stack trace,
//     logged at error level and reported to Sentry;
//   - auth.CtxUser returns no error but a nil user.
func assertRequestUser(w http.ResponseWriter, r *http.Request) (auth.User, bool) {
	ctx := r.Context()

	user, err := auth.CtxUser(ctx)

	switch {
	case err != nil:
		wrapped := errors.WithStack(err)
		logger.Error(ctx, "could not retrieve user", logger.E(wrapped))
		sentry.CaptureException(wrapped)
	case user != nil:
		return user, true
	}

	forbidden(w, r)

	return nil, false
}
|
|
|
|
|
|
|
|
// forbidden logs a warning that records the request path, then writes an
// HTTP 403 API error response carrying ErrCodeForbidden. The final nil
// argument to api.ErrorResponse is presumably an optional data payload —
// confirm against the goweb api package.
func forbidden(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	path := r.URL.Path

	logger.Warn(ctx, "forbidden", logger.F("path", path))

	api.ErrorResponse(w, http.StatusForbidden, ErrCodeForbidden, nil)
}
|
|
|
|
|
|
|
|
// logUnexpectedUserType emits a warning for an authenticated user whose
// concrete type is not covered by a caller's type switch. The log entry holds
// the user's subject and its dynamic Go type, formatted with %T.
func logUnexpectedUserType(ctx context.Context, user auth.User) {
	subject := user.Subject()
	dynamicType := fmt.Sprintf("%T", user)

	logger.Warn(ctx, "unexpected user type",
		logger.F("subject", subject),
		logger.F("type", dynamicType),
	)
}
|