package server import ( "encoding/json" "net/http" "strconv" "strings" "forge.cadoles.com/Cadoles/emissary/internal/agent/metadata" "forge.cadoles.com/Cadoles/emissary/internal/datastore" "forge.cadoles.com/Cadoles/emissary/internal/jwk" "github.com/go-chi/chi" "github.com/pkg/errors" "gitlab.com/wpetit/goweb/api" "gitlab.com/wpetit/goweb/logger" ) const ( ErrCodeUnknownError api.ErrorCode = "unknown-error" ErrCodeNotFound api.ErrorCode = "not-found" ErrInvalidSignature api.ErrorCode = "invalid-signature" ) type registerAgentRequest struct { KeySet json.RawMessage `json:"keySet" validate:"required"` Metadata []metadata.Tuple `json:"metadata" validate:"required"` Thumbprint string `json:"thumbprint" validate:"required"` Signature string `json:"signature" validate:"required"` } func (s *Server) registerAgent(w http.ResponseWriter, r *http.Request) { registerAgentReq := ®isterAgentRequest{} if ok := api.Bind(w, r, registerAgentReq); !ok { return } ctx := r.Context() keySet, err := jwk.Parse(registerAgentReq.KeySet) if err != nil { logger.Error(ctx, "could not parse key set", logger.E(errors.WithStack(err))) api.ErrorResponse(w, http.StatusInternalServerError, ErrCodeUnknownError, nil) return } ctx = logger.With(ctx, logger.F("agentThumbprint", registerAgentReq.Thumbprint)) validSignature, err := jwk.Verify(keySet, registerAgentReq.Signature, registerAgentReq.Thumbprint, registerAgentReq.Metadata) if err != nil { logger.Error(ctx, "could not validate signature", logger.E(errors.WithStack(err))) api.ErrorResponse(w, http.StatusInternalServerError, ErrCodeUnknownError, nil) return } if !validSignature { logger.Error(ctx, "invalid signature", logger.F("signature", registerAgentReq.Signature)) api.ErrorResponse(w, http.StatusBadRequest, ErrInvalidSignature, nil) return } metadata := metadata.FromSorted(registerAgentReq.Metadata) agent, err := s.agentRepo.Create( ctx, registerAgentReq.Thumbprint, keySet, metadata, ) if err != nil { if !errors.Is(err, datastore.ErrAlreadyExist) { logger.Error(ctx, "could not create agent", logger.E(errors.WithStack(err))) api.ErrorResponse(w, http.StatusInternalServerError, ErrCodeUnknownError, nil) return } agents, _, err := s.agentRepo.Query( ctx, datastore.WithAgentQueryThumbprints(registerAgentReq.Thumbprint), datastore.WithAgentQueryLimit(1), ) if err != nil { logger.Error(ctx, "could not retrieve agents", logger.E(errors.WithStack(err))) api.ErrorResponse(w, http.StatusInternalServerError, ErrCodeUnknownError, nil) return } if len(agents) == 0 { logger.Error(ctx, "could not retrieve matching agent", logger.E(errors.WithStack(err))) api.ErrorResponse(w, http.StatusInternalServerError, ErrCodeNotFound, nil) return } agent, err = s.agentRepo.Update( ctx, agents[0].ID, datastore.WithAgentUpdateKeySet(keySet), datastore.WithAgentUpdateMetadata(metadata), datastore.WithAgentUpdateThumbprint(registerAgentReq.Thumbprint), ) if err != nil { logger.Error(ctx, "could not update agent", logger.E(errors.WithStack(err))) api.ErrorResponse(w, http.StatusInternalServerError, ErrCodeUnknownError, nil) return } } api.DataResponse(w, http.StatusCreated, struct { Agent *datastore.Agent `json:"agent"` }{ Agent: agent, }) } type updateAgentRequest struct { Status *datastore.AgentStatus `json:"status" validate:"omitempty,oneof=0 1 2 3"` } func (s *Server) updateAgent(w http.ResponseWriter, r *http.Request) { agentID, ok := getAgentID(w, r) if !ok { return } ctx := r.Context() updateAgentReq := &updateAgentRequest{} if ok := api.Bind(w, r, updateAgentReq); !ok { return } options := make([]datastore.AgentUpdateOptionFunc, 0) if updateAgentReq.Status != nil { options = append(options, datastore.WithAgentUpdateStatus(*updateAgentReq.Status)) } agent, err := s.agentRepo.Update( ctx, datastore.AgentID(agentID), options..., ) if err != nil { logger.Error(ctx, "could not update agent", logger.E(errors.WithStack(err))) api.ErrorResponse(w, http.StatusInternalServerError, ErrCodeUnknownError, nil) return } api.DataResponse(w, http.StatusOK, struct { Agent *datastore.Agent `json:"agent"` }{ Agent: agent, }) } func (s *Server) queryAgents(w http.ResponseWriter, r *http.Request) { limit, ok := getIntQueryParam(w, r, "limit", 10) if !ok { return } offset, ok := getIntQueryParam(w, r, "offset", 0) if !ok { return } options := []datastore.AgentQueryOptionFunc{ datastore.WithAgentQueryLimit(int(limit)), datastore.WithAgentQueryOffset(int(offset)), } ids, ok := getIntSliceValues(w, r, "ids", nil) if !ok { return } if ids != nil { agentIDs := func(ids []int64) []datastore.AgentID { agentIDs := make([]datastore.AgentID, 0, len(ids)) for _, id := range ids { agentIDs = append(agentIDs, datastore.AgentID(id)) } return agentIDs }(ids) options = append(options, datastore.WithAgentQueryID(agentIDs...)) } thumbprints, ok := getStringSliceValues(w, r, "thumbprints", nil) if !ok { return } if thumbprints != nil { options = append(options, datastore.WithAgentQueryThumbprints(thumbprints...)) } statuses, ok := getIntSliceValues(w, r, "statuses", nil) if !ok { return } if statuses != nil { agentStatuses := func(statuses []int64) []datastore.AgentStatus { agentStatuses := make([]datastore.AgentStatus, 0, len(statuses)) for _, status := range statuses { agentStatuses = append(agentStatuses, datastore.AgentStatus(status)) } return agentStatuses }(statuses) options = append(options, datastore.WithAgentQueryStatus(agentStatuses...)) } ctx := r.Context() agents, total, err := s.agentRepo.Query( ctx, options..., ) if err != nil { logger.Error(ctx, "could not list agents", logger.E(errors.WithStack(err))) api.ErrorResponse(w, http.StatusInternalServerError, ErrCodeUnknownError, nil) return } api.DataResponse(w, http.StatusOK, struct { Agents []*datastore.Agent `json:"agents"` Total int `json:"total"` }{ Agents: agents, Total: total, }) } func (s *Server) deleteAgent(w http.ResponseWriter, r *http.Request) { agentID, ok := getAgentID(w, r) if !ok { return } ctx := r.Context() err := s.agentRepo.Delete( ctx, datastore.AgentID(agentID), ) if err != nil { if errors.Is(err, datastore.ErrNotFound) { api.ErrorResponse(w, http.StatusNotFound, ErrCodeNotFound, nil) return } logger.Error(ctx, "could not delete agent", logger.E(errors.WithStack(err))) api.ErrorResponse(w, http.StatusInternalServerError, ErrCodeUnknownError, nil) return } api.DataResponse(w, http.StatusOK, struct { AgentID datastore.AgentID `json:"agentId"` }{ AgentID: datastore.AgentID(agentID), }) } func (s *Server) getAgent(w http.ResponseWriter, r *http.Request) { agentID, ok := getAgentID(w, r) if !ok { return } ctx := r.Context() agent, err := s.agentRepo.Get( ctx, datastore.AgentID(agentID), ) if err != nil { if errors.Is(err, datastore.ErrNotFound) { api.ErrorResponse(w, http.StatusNotFound, ErrCodeNotFound, nil) return } logger.Error(ctx, "could not get agent", logger.E(errors.WithStack(err))) api.ErrorResponse(w, http.StatusInternalServerError, ErrCodeUnknownError, nil) return } api.DataResponse(w, http.StatusOK, struct { Agent *datastore.Agent `json:"agent"` }{ Agent: agent, }) } func getAgentID(w http.ResponseWriter, r *http.Request) (datastore.AgentID, bool) { rawAgentID := chi.URLParam(r, "agentID") agentID, err := strconv.ParseInt(rawAgentID, 10, 64) if err != nil { logger.Error(r.Context(), "could not parse agent id", logger.E(errors.WithStack(err))) api.ErrorResponse(w, http.StatusBadRequest, api.ErrCodeMalformedRequest, nil) return 0, false } return datastore.AgentID(agentID), true } func getIntQueryParam(w http.ResponseWriter, r *http.Request, param string, defaultValue int64) (int64, bool) { rawValue := r.URL.Query().Get(param) if rawValue != "" { value, err := strconv.ParseInt(rawValue, 10, 64) if err != nil { logger.Error(r.Context(), "could not parse int param", logger.F("param", param), logger.E(errors.WithStack(err))) api.ErrorResponse(w, http.StatusBadRequest, api.ErrCodeMalformedRequest, nil) return 0, false } return value, true } return defaultValue, true } func getStringSliceValues(w http.ResponseWriter, r *http.Request, param string, defaultValue []string) ([]string, bool) { rawValue := r.URL.Query().Get(param) if rawValue != "" { values := strings.Split(rawValue, ",") return values, true } return defaultValue, true } func getIntSliceValues(w http.ResponseWriter, r *http.Request, param string, defaultValue []int64) ([]int64, bool) { rawValue := r.URL.Query().Get(param) if rawValue != "" { rawValues := strings.Split(rawValue, ",") values := make([]int64, 0, len(rawValues)) for _, rv := range rawValues { value, err := strconv.ParseInt(rv, 10, 64) if err != nil { logger.Error(r.Context(), "could not parse int slice param", logger.F("param", param), logger.E(errors.WithStack(err))) api.ErrorResponse(w, http.StatusBadRequest, api.ErrCodeMalformedRequest, nil) return nil, false } values = append(values, value) } return values, true } return defaultValue, true }