421 lines
10 KiB
Go
421 lines
10 KiB
Go
package server
|
|
|
|
import (
|
|
"encoding/json"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"forge.cadoles.com/Cadoles/emissary/internal/agent/metadata"
|
|
"forge.cadoles.com/Cadoles/emissary/internal/datastore"
|
|
"forge.cadoles.com/Cadoles/emissary/internal/jwk"
|
|
"github.com/go-chi/chi"
|
|
"github.com/pkg/errors"
|
|
"gitlab.com/wpetit/goweb/api"
|
|
"gitlab.com/wpetit/goweb/logger"
|
|
)
|
|
|
|
const (
|
|
ErrCodeUnknownError api.ErrorCode = "unknown-error"
|
|
ErrCodeNotFound api.ErrorCode = "not-found"
|
|
ErrCodeInvalidSignature api.ErrorCode = "invalid-signature"
|
|
ErrCodeConflict api.ErrorCode = "conflict"
|
|
)
|
|
|
|
type registerAgentRequest struct {
|
|
KeySet json.RawMessage `json:"keySet" validate:"required"`
|
|
Metadata []metadata.Tuple `json:"metadata" validate:"required"`
|
|
Thumbprint string `json:"thumbprint" validate:"required"`
|
|
Signature string `json:"signature" validate:"required"`
|
|
}
|
|
|
|
func (s *Server) registerAgent(w http.ResponseWriter, r *http.Request) {
|
|
registerAgentReq := ®isterAgentRequest{}
|
|
if ok := api.Bind(w, r, registerAgentReq); !ok {
|
|
return
|
|
}
|
|
|
|
ctx := r.Context()
|
|
|
|
keySet, err := jwk.Parse(registerAgentReq.KeySet)
|
|
if err != nil {
|
|
err = errors.WithStack(err)
|
|
logger.Error(ctx, "could not parse key set", logger.CapturedE(err))
|
|
api.ErrorResponse(w, http.StatusInternalServerError, ErrCodeUnknownError, nil)
|
|
|
|
return
|
|
}
|
|
|
|
ctx = logger.With(ctx, logger.F("agentThumbprint", registerAgentReq.Thumbprint))
|
|
|
|
// Validate that the existing signature validates the request
|
|
|
|
validSignature, err := jwk.Verify(keySet, registerAgentReq.Signature, registerAgentReq.Thumbprint, registerAgentReq.Metadata)
|
|
if err != nil {
|
|
err = errors.WithStack(err)
|
|
logger.Error(ctx, "could not validate signature", logger.CapturedE(err))
|
|
api.ErrorResponse(w, http.StatusInternalServerError, ErrCodeUnknownError, nil)
|
|
|
|
return
|
|
}
|
|
|
|
if !validSignature {
|
|
logger.Warn(ctx, "conflicting signature", logger.F("signature", registerAgentReq.Signature))
|
|
api.ErrorResponse(w, http.StatusConflict, ErrCodeConflict, nil)
|
|
|
|
return
|
|
}
|
|
|
|
metadata := metadata.FromSorted(registerAgentReq.Metadata)
|
|
|
|
agent, err := s.agentRepo.Create(
|
|
ctx,
|
|
registerAgentReq.Thumbprint,
|
|
keySet,
|
|
metadata,
|
|
)
|
|
if err != nil {
|
|
if !errors.Is(err, datastore.ErrAlreadyExist) {
|
|
err = errors.WithStack(err)
|
|
logger.Error(ctx, "could not create agent", logger.CapturedE(err))
|
|
api.ErrorResponse(w, http.StatusInternalServerError, ErrCodeUnknownError, nil)
|
|
|
|
return
|
|
}
|
|
|
|
agents, _, err := s.agentRepo.Query(
|
|
ctx,
|
|
datastore.WithAgentQueryThumbprints(registerAgentReq.Thumbprint),
|
|
datastore.WithAgentQueryLimit(1),
|
|
)
|
|
if err != nil {
|
|
err = errors.WithStack(err)
|
|
logger.Error(ctx, "could not retrieve agents", logger.CapturedE(err))
|
|
api.ErrorResponse(w, http.StatusInternalServerError, ErrCodeUnknownError, nil)
|
|
|
|
return
|
|
}
|
|
|
|
if len(agents) == 0 {
|
|
err = errors.WithStack(err)
|
|
logger.Error(ctx, "could not retrieve matching agent", logger.CapturedE(err))
|
|
api.ErrorResponse(w, http.StatusInternalServerError, ErrCodeNotFound, nil)
|
|
|
|
return
|
|
}
|
|
|
|
agentID := agents[0].ID
|
|
|
|
agent, err = s.agentRepo.Get(ctx, agentID)
|
|
if err != nil {
|
|
err = errors.WithStack(err)
|
|
logger.Error(
|
|
ctx, "could not retrieve agent",
|
|
logger.CapturedE(err), logger.F("agentID", agentID),
|
|
)
|
|
|
|
api.ErrorResponse(w, http.StatusInternalServerError, ErrCodeUnknownError, nil)
|
|
|
|
return
|
|
}
|
|
|
|
validSignature, err = jwk.Verify(agent.KeySet.Set, registerAgentReq.Signature, registerAgentReq.Thumbprint, registerAgentReq.Metadata)
|
|
if err != nil {
|
|
err = errors.WithStack(err)
|
|
logger.Error(ctx, "could not validate signature using previous keyset", logger.CapturedE(err))
|
|
|
|
api.ErrorResponse(w, http.StatusConflict, ErrCodeConflict, nil)
|
|
|
|
return
|
|
}
|
|
|
|
agent, err = s.agentRepo.Update(
|
|
ctx, agents[0].ID,
|
|
datastore.WithAgentUpdateKeySet(keySet),
|
|
datastore.WithAgentUpdateMetadata(metadata),
|
|
datastore.WithAgentUpdateThumbprint(registerAgentReq.Thumbprint),
|
|
)
|
|
if err != nil {
|
|
err = errors.WithStack(err)
|
|
logger.Error(ctx, "could not update agent", logger.CapturedE(err))
|
|
|
|
api.ErrorResponse(w, http.StatusInternalServerError, ErrCodeUnknownError, nil)
|
|
|
|
return
|
|
}
|
|
}
|
|
|
|
api.DataResponse(w, http.StatusCreated, struct {
|
|
Agent *datastore.Agent `json:"agent"`
|
|
}{
|
|
Agent: agent,
|
|
})
|
|
}
|
|
|
|
type updateAgentRequest struct {
|
|
Status *datastore.AgentStatus `json:"status" validate:"omitempty,oneof=0 1 2 3"`
|
|
Label *string `json:"label" validate:"omitempty"`
|
|
}
|
|
|
|
// updateAgent applies a partial update (status and/or label) to the agent
// identified by the "agentID" URL parameter.
//
// It responds 200 with the updated agent, 400 when the agent ID or body is
// malformed, 404 when the repository reports datastore.ErrNotFound (matching
// getAgent and deleteAgent), and 500 for any other repository failure.
func (s *Server) updateAgent(w http.ResponseWriter, r *http.Request) {
	agentID, ok := getAgentID(w, r)
	if !ok {
		return
	}

	ctx := r.Context()

	updateAgentReq := &updateAgentRequest{}
	if ok := api.Bind(w, r, updateAgentReq); !ok {
		return
	}

	// Only fields present in the request turn into update options.
	options := make([]datastore.AgentUpdateOptionFunc, 0)

	if updateAgentReq.Status != nil {
		options = append(options, datastore.WithAgentUpdateStatus(*updateAgentReq.Status))
	}

	if updateAgentReq.Label != nil {
		options = append(options, datastore.WithAgentUpdateLabel(*updateAgentReq.Label))
	}

	// getAgentID already returns a datastore.AgentID; no conversion needed.
	agent, err := s.agentRepo.Update(
		ctx,
		agentID,
		options...,
	)
	if err != nil {
		if errors.Is(err, datastore.ErrNotFound) {
			api.ErrorResponse(w, http.StatusNotFound, ErrCodeNotFound, nil)

			return
		}

		err = errors.WithStack(err)
		logger.Error(ctx, "could not update agent", logger.CapturedE(err))
		api.ErrorResponse(w, http.StatusInternalServerError, ErrCodeUnknownError, nil)

		return
	}

	api.DataResponse(w, http.StatusOK, struct {
		Agent *datastore.Agent `json:"agent"`
	}{
		Agent: agent,
	})
}
|
|
|
|
// queryAgents lists agents, filtered and paginated by query parameters:
//
//   - limit (default 10) and offset (default 0): integers;
//   - ids: comma-separated agent IDs;
//   - thumbprints: comma-separated thumbprints;
//   - statuses: comma-separated integer agent statuses.
//
// Parameters are parsed in that order; the first malformed one aborts the
// request with the 400 response written by the parsing helper. On success it
// responds 200 with the matching agents and the total count reported by the
// repository.
func (s *Server) queryAgents(w http.ResponseWriter, r *http.Request) {
	limit, ok := getIntQueryParam(w, r, "limit", 10)
	if !ok {
		return
	}

	offset, ok := getIntQueryParam(w, r, "offset", 0)
	if !ok {
		return
	}

	opts := []datastore.AgentQueryOptionFunc{
		datastore.WithAgentQueryLimit(int(limit)),
		datastore.WithAgentQueryOffset(int(offset)),
	}

	rawIDs, ok := getIntSliceValues(w, r, "ids", nil)
	if !ok {
		return
	}

	if rawIDs != nil {
		agentIDs := make([]datastore.AgentID, len(rawIDs))
		for i, rawID := range rawIDs {
			agentIDs[i] = datastore.AgentID(rawID)
		}

		opts = append(opts, datastore.WithAgentQueryID(agentIDs...))
	}

	thumbprints, ok := getStringSliceValues(w, r, "thumbprints", nil)
	if !ok {
		return
	}

	if thumbprints != nil {
		opts = append(opts, datastore.WithAgentQueryThumbprints(thumbprints...))
	}

	rawStatuses, ok := getIntSliceValues(w, r, "statuses", nil)
	if !ok {
		return
	}

	if rawStatuses != nil {
		agentStatuses := make([]datastore.AgentStatus, len(rawStatuses))
		for i, rawStatus := range rawStatuses {
			agentStatuses[i] = datastore.AgentStatus(rawStatus)
		}

		opts = append(opts, datastore.WithAgentQueryStatus(agentStatuses...))
	}

	ctx := r.Context()

	agents, total, err := s.agentRepo.Query(ctx, opts...)
	if err != nil {
		logger.Error(ctx, "could not list agents", logger.CapturedE(errors.WithStack(err)))
		api.ErrorResponse(w, http.StatusInternalServerError, ErrCodeUnknownError, nil)

		return
	}

	api.DataResponse(w, http.StatusOK, struct {
		Agents []*datastore.Agent `json:"agents"`
		Total  int                `json:"total"`
	}{
		Agents: agents,
		Total:  total,
	})
}
|
|
|
|
// deleteAgent removes the agent identified by the "agentID" URL parameter.
//
// It responds 200 with the deleted agent's ID, 404 when the repository
// reports datastore.ErrNotFound, and 500 for any other repository failure.
func (s *Server) deleteAgent(w http.ResponseWriter, r *http.Request) {
	agentID, ok := getAgentID(w, r)
	if !ok {
		return
	}

	ctx := r.Context()

	if err := s.agentRepo.Delete(ctx, agentID); err != nil {
		switch {
		case errors.Is(err, datastore.ErrNotFound):
			api.ErrorResponse(w, http.StatusNotFound, ErrCodeNotFound, nil)
		default:
			err = errors.WithStack(err)
			logger.Error(ctx, "could not delete agent", logger.CapturedE(err))
			api.ErrorResponse(w, http.StatusInternalServerError, ErrCodeUnknownError, nil)
		}

		return
	}

	api.DataResponse(w, http.StatusOK, struct {
		AgentID datastore.AgentID `json:"agentId"`
	}{
		AgentID: agentID,
	})
}
|
|
|
|
// getAgent returns the agent identified by the "agentID" URL parameter.
//
// It responds 200 with the agent, 404 when the repository reports
// datastore.ErrNotFound, and 500 for any other repository failure.
func (s *Server) getAgent(w http.ResponseWriter, r *http.Request) {
	agentID, ok := getAgentID(w, r)
	if !ok {
		return
	}

	ctx := r.Context()

	found, err := s.agentRepo.Get(ctx, agentID)
	if err != nil {
		switch {
		case errors.Is(err, datastore.ErrNotFound):
			api.ErrorResponse(w, http.StatusNotFound, ErrCodeNotFound, nil)
		default:
			err = errors.WithStack(err)
			logger.Error(ctx, "could not get agent", logger.CapturedE(err))
			api.ErrorResponse(w, http.StatusInternalServerError, ErrCodeUnknownError, nil)
		}

		return
	}

	api.DataResponse(w, http.StatusOK, struct {
		Agent *datastore.Agent `json:"agent"`
	}{
		Agent: found,
	})
}
|
|
|
|
// getAgentID parses the "agentID" chi URL parameter as a base-10 int64.
//
// On failure it logs the parse error, writes a 400 response with
// api.ErrCodeMalformedRequest, and returns (0, false); the caller should
// then stop handling the request.
func getAgentID(w http.ResponseWriter, r *http.Request) (datastore.AgentID, bool) {
	parsed, err := strconv.ParseInt(chi.URLParam(r, "agentID"), 10, 64)
	if err != nil {
		err = errors.WithStack(err)
		logger.Error(r.Context(), "could not parse agent id", logger.CapturedE(err))
		api.ErrorResponse(w, http.StatusBadRequest, api.ErrCodeMalformedRequest, nil)

		return 0, false
	}

	return datastore.AgentID(parsed), true
}
|
|
|
|
// getIntQueryParam reads query parameter param as a base-10 int64.
//
// An absent or empty parameter yields (defaultValue, true). A value that
// strconv.ParseInt rejects is logged, answered with a 400
// api.ErrCodeMalformedRequest response, and reported as (0, false).
func getIntQueryParam(w http.ResponseWriter, r *http.Request, param string, defaultValue int64) (int64, bool) {
	raw := r.URL.Query().Get(param)
	if raw == "" {
		return defaultValue, true
	}

	parsed, err := strconv.ParseInt(raw, 10, 64)
	if err != nil {
		logger.Error(
			r.Context(), "could not parse int param",
			logger.F("param", param), logger.CapturedE(errors.WithStack(err)),
		)
		api.ErrorResponse(w, http.StatusBadRequest, api.ErrCodeMalformedRequest, nil)

		return 0, false
	}

	return parsed, true
}
|
|
|
|
// getStringSliceValues reads query parameter param as a comma-separated list.
//
// An absent or empty parameter yields (defaultValue, true). Otherwise the raw
// value is split on "," as-is: elements are not trimmed and empty elements
// are kept. It never fails; w is unused and is kept only so the signature
// matches the other query helpers.
func getStringSliceValues(w http.ResponseWriter, r *http.Request, param string, defaultValue []string) ([]string, bool) {
	raw := r.URL.Query().Get(param)
	if raw == "" {
		return defaultValue, true
	}

	return strings.Split(raw, ","), true
}
|
|
|
|
// getIntSliceValues reads query parameter param as a comma-separated list of
// base-10 int64 values.
//
// An absent or empty parameter yields (defaultValue, true). Elements are not
// trimmed, so a blank or space-padded element fails to parse. The first
// element that strconv.ParseInt rejects is logged, answered with a 400
// api.ErrCodeMalformedRequest response, and reported as (nil, false).
func getIntSliceValues(w http.ResponseWriter, r *http.Request, param string, defaultValue []int64) ([]int64, bool) {
	raw := r.URL.Query().Get(param)
	if raw == "" {
		return defaultValue, true
	}

	parts := strings.Split(raw, ",")
	parsed := make([]int64, len(parts))

	for i, part := range parts {
		n, err := strconv.ParseInt(part, 10, 64)
		if err != nil {
			logger.Error(
				r.Context(), "could not parse int slice param",
				logger.F("param", param), logger.CapturedE(errors.WithStack(err)),
			)
			api.ErrorResponse(w, http.StatusBadRequest, api.ErrCodeMalformedRequest, nil)

			return nil, false
		}

		parsed[i] = n
	}

	return parsed, true
}
|