package server import ( "net/http" "strconv" "strings" "forge.cadoles.com/Cadoles/emissary/internal/datastore" "github.com/go-chi/chi" "github.com/pkg/errors" "gitlab.com/wpetit/goweb/api" "gitlab.com/wpetit/goweb/logger" ) const ( ErrCodeUnknownError api.ErrorCode = "unknown-error" ErrCodeNotFound api.ErrorCode = "not-found" ErrCodeAlreadyRegistered api.ErrorCode = "already-registered" ) type registerAgentRequest struct { RemoteID string } func (s *Server) registerAgent(w http.ResponseWriter, r *http.Request) { registerAgentReq := ®isterAgentRequest{} if ok := api.Bind(w, r, registerAgentReq); !ok { return } ctx := r.Context() agent, err := s.agentRepo.Create( ctx, registerAgentReq.RemoteID, datastore.AgentStatusPending, ) if err != nil { if errors.Is(err, datastore.ErrAlreadyExist) { logger.Error(ctx, "agent already registered", logger.F("remoteID", registerAgentReq.RemoteID)) api.ErrorResponse(w, http.StatusConflict, ErrCodeAlreadyRegistered, nil) return } logger.Error(ctx, "could not create agent", logger.E(errors.WithStack(err))) api.ErrorResponse(w, http.StatusInternalServerError, ErrCodeUnknownError, nil) return } api.DataResponse(w, http.StatusCreated, struct { Agent *datastore.Agent }{ Agent: agent, }) } type updateAgentRequest struct { Status *datastore.AgentStatus } func (s *Server) updateAgent(w http.ResponseWriter, r *http.Request) { agentID, ok := getAgentID(w, r) if !ok { return } ctx := r.Context() updateAgentReq := &updateAgentRequest{} if ok := api.Bind(w, r, updateAgentReq); !ok { return } options := make([]datastore.AgentUpdateOptionFunc, 0) if updateAgentReq.Status != nil { options = append(options, datastore.WithAgentUpdateStatus(*updateAgentReq.Status)) } agent, err := s.agentRepo.Update( ctx, datastore.AgentID(agentID), options..., ) if err != nil { logger.Error(ctx, "could not update agent", logger.E(errors.WithStack(err))) api.ErrorResponse(w, http.StatusInternalServerError, ErrCodeUnknownError, nil) return } api.DataResponse(w, http.StatusOK, struct { Agent *datastore.Agent }{ Agent: agent, }) } func (s *Server) queryAgents(w http.ResponseWriter, r *http.Request) { limit, ok := getIntQueryParam(w, r, "limit", 10) if !ok { return } offset, ok := getIntQueryParam(w, r, "offset", 0) if !ok { return } options := []datastore.AgentQueryOptionFunc{ datastore.WithAgentQueryLimit(int(limit)), datastore.WithAgentQueryOffset(int(offset)), } ids, ok := getIntSliceValues(w, r, "ids", nil) if !ok { return } if ids != nil { agentIDs := func(ids []int64) []datastore.AgentID { agentIDs := make([]datastore.AgentID, 0, len(ids)) for _, id := range ids { agentIDs = append(agentIDs, datastore.AgentID(id)) } return agentIDs }(ids) options = append(options, datastore.WithAgentQueryID(agentIDs...)) } remoteIDs, ok := getStringSliceValues(w, r, "remote_ids", nil) if !ok { return } if remoteIDs != nil { options = append(options, datastore.WithAgentQueryRemoteID(remoteIDs...)) } statuses, ok := getIntSliceValues(w, r, "statuses", nil) if !ok { return } if statuses != nil { agentStatuses := func(statuses []int64) []datastore.AgentStatus { agentStatuses := make([]datastore.AgentStatus, 0, len(statuses)) for _, status := range statuses { agentStatuses = append(agentStatuses, datastore.AgentStatus(status)) } return agentStatuses }(statuses) options = append(options, datastore.WithAgentQueryStatus(agentStatuses...)) } ctx := r.Context() agents, total, err := s.agentRepo.Query( ctx, options..., ) if err != nil { logger.Error(ctx, "could not list agents", logger.E(errors.WithStack(err))) api.ErrorResponse(w, http.StatusInternalServerError, ErrCodeUnknownError, nil) return } api.DataResponse(w, http.StatusOK, struct { Agents []*datastore.Agent Total int }{ Agents: agents, Total: total, }) } func (s *Server) deleteAgent(w http.ResponseWriter, r *http.Request) { agentID, ok := getAgentID(w, r) if !ok { return } ctx := r.Context() err := s.agentRepo.Delete( ctx, datastore.AgentID(agentID), ) if err != nil { if errors.Is(err, datastore.ErrNotFound) { api.ErrorResponse(w, http.StatusNotFound, ErrCodeNotFound, nil) return } logger.Error(ctx, "could not delete agent", logger.E(errors.WithStack(err))) api.ErrorResponse(w, http.StatusInternalServerError, ErrCodeUnknownError, nil) return } api.DataResponse(w, http.StatusOK, struct { AgentID datastore.AgentID }{ AgentID: datastore.AgentID(agentID), }) } func (s *Server) getAgent(w http.ResponseWriter, r *http.Request) { agentID, ok := getAgentID(w, r) if !ok { return } ctx := r.Context() agent, err := s.agentRepo.Get( ctx, datastore.AgentID(agentID), ) if err != nil { if errors.Is(err, datastore.ErrNotFound) { api.ErrorResponse(w, http.StatusNotFound, ErrCodeNotFound, nil) return } logger.Error(ctx, "could not get agent", logger.E(errors.WithStack(err))) api.ErrorResponse(w, http.StatusInternalServerError, ErrCodeUnknownError, nil) return } api.DataResponse(w, http.StatusOK, struct { Agent *datastore.Agent }{ Agent: agent, }) } func getAgentID(w http.ResponseWriter, r *http.Request) (datastore.AgentID, bool) { rawAgentID := chi.URLParam(r, "agentID") agentID, err := strconv.ParseInt(rawAgentID, 10, 64) if err != nil { logger.Error(r.Context(), "could not parse agent id", logger.E(errors.WithStack(err))) api.ErrorResponse(w, http.StatusBadRequest, api.ErrCodeMalformedRequest, nil) return 0, false } return datastore.AgentID(agentID), true } func getIntQueryParam(w http.ResponseWriter, r *http.Request, param string, defaultValue int64) (int64, bool) { rawValue := r.URL.Query().Get(param) if rawValue != "" { value, err := strconv.ParseInt(rawValue, 10, 64) if err != nil { logger.Error(r.Context(), "could not parse int param", logger.F("param", param), logger.E(errors.WithStack(err))) api.ErrorResponse(w, http.StatusBadRequest, api.ErrCodeMalformedRequest, nil) return 0, false } return value, true } return defaultValue, true } func getStringSliceValues(w http.ResponseWriter, r *http.Request, param string, defaultValue []string) ([]string, bool) { rawValue := r.URL.Query().Get(param) if rawValue != "" { values := strings.Split(rawValue, ",") return values, true } return defaultValue, true } func getIntSliceValues(w http.ResponseWriter, r *http.Request, param string, defaultValue []int64) ([]int64, bool) { rawValue := r.URL.Query().Get(param) if rawValue != "" { rawValues := strings.Split(rawValue, ",") values := make([]int64, 0, len(rawValues)) for _, rv := range rawValues { value, err := strconv.ParseInt(rv, 10, 64) if err != nil { logger.Error(r.Context(), "could not parse int slice param", logger.F("param", param), logger.E(errors.WithStack(err))) api.ErrorResponse(w, http.StatusBadRequest, api.ErrCodeMalformedRequest, nil) return nil, false } values = append(values, value) } return values, true } return defaultValue, true }