package server import ( "bytes" "context" "crypto/sha256" "errors" "io/ioutil" "strings" "time" "forge.cadoles.com/wpetit/go-http-peering/crypto" peering "forge.cadoles.com/wpetit/go-http-peering" jwt "github.com/dgrijalva/jwt-go" "net/http" ) const ( AuthorizationType = "Bearer" KeyPeerID ContextKey = "peerID" ) var ( ErrInvalidClaims = errors.New("invalid claims") ErrInvalidChecksum = errors.New("invalid checksum") ErrNotPeered = errors.New("not peered") ) type ContextKey string func Authenticate(store peering.Store, funcs ...OptionFunc) func(http.Handler) http.Handler { options := createOptions(funcs...) logger := options.Logger middleware := func(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { authorization := r.Header.Get("Authorization") if authorization == "" { sendError(w, http.StatusUnauthorized) return } parts := strings.SplitN(authorization, " ", 2) if len(parts) != 2 || parts[0] != AuthorizationType { sendError(w, http.StatusUnauthorized) return } token, err := jwt.ParseWithClaims(parts[1], &peering.PeerClaims{}, func(token *jwt.Token) (interface{}, error) { claims, ok := token.Claims.(*peering.PeerClaims) if !ok { return nil, ErrInvalidClaims } peerID := peering.PeerID(claims.Issuer) peer, err := store.Get(peerID) if err != nil { return nil, err } if peer.Status == peering.StatusRejected { return nil, ErrPeerRejected } if peer.Status != peering.StatusPeered { return nil, ErrNotPeered } publicKey, err := crypto.DecodePEMToPublicKey(peer.PublicKey) if err != nil { return nil, err } return publicKey, nil }) if err != nil || !token.Valid { logger.Printf("[ERROR] %s", err) if err == ErrPeerRejected { sendError(w, http.StatusForbidden) } else { sendError(w, http.StatusUnauthorized) } return } claims, ok := token.Claims.(*peering.PeerClaims) if !ok { logger.Printf("[ERROR] %s", ErrInvalidClaims) sendError(w, http.StatusUnauthorized) return } body, err := ioutil.ReadAll(r.Body) if err != nil { logger.Printf("[ERROR] %s", err) sendError(w, http.StatusInternalServerError) return } if err := r.Body.Close(); err != nil { logger.Printf("[ERROR] %s", err) sendError(w, http.StatusInternalServerError) return } match, err := compareChecksum(body, claims.BodySum) if err != nil { logger.Printf("[ERROR] %s", err) sendError(w, http.StatusUnauthorized) return } if !match { logger.Printf("[ERROR] %s", ErrInvalidChecksum) sendError(w, http.StatusBadRequest) return } peerID := peering.PeerID(claims.Issuer) if err := store.UpdateLastContact(peerID, r.RemoteAddr, time.Now()); err != nil { logger.Printf("[ERROR] %s", err) sendError(w, http.StatusInternalServerError) return } ctx := context.WithValue(r.Context(), KeyPeerID, peerID) r = r.WithContext(ctx) r.Body = ioutil.NopCloser(bytes.NewBuffer(body)) next.ServeHTTP(w, r) } return http.HandlerFunc(fn) } return middleware } func GetPeerID(r *http.Request) (peering.PeerID, error) { peerID, ok := r.Context().Value(KeyPeerID).(peering.PeerID) if !ok { return "", ErrUnauthorized } return peerID, nil } func sendError(w http.ResponseWriter, status int) { http.Error(w, http.StatusText(status), status) } func compareChecksum(body []byte, sum []byte) (bool, error) { sha := sha256.New() _, err := sha.Write(body) if err != nil { return false, err } return bytes.Equal(sum, sha.Sum(nil)), nil }