package server import ( "bytes" "context" "crypto/rsa" "crypto/sha256" "io" "io/ioutil" "net/http" "time" peeringCrypto "forge.cadoles.com/Cadoles/go-http-peering/crypto" peering "forge.cadoles.com/Cadoles/go-http-peering" jwt "github.com/dgrijalva/jwt-go" "github.com/pkg/errors" ) const ( ServerTokenHeader = "X-Server-Token" // nolint: gosec ClientTokenHeader = "X-Client-Token" KeyPeerID ContextKey = "PeerID" ) var ( ErrInvalidClaims = errors.New("invalid claims") ErrInvalidChecksum = errors.New("invalid checksum") ErrNotPeered = errors.New("not peered") ) type ContextKey string func Authenticate(store peering.Store, key *rsa.PublicKey, funcs ...OptionFunc) func(http.Handler) http.Handler { options := createOptions(funcs...) logger := options.Logger middleware := func(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { serverToken := r.Header.Get(ServerTokenHeader) if serverToken == "" { sendError(w, http.StatusUnauthorized) return } clientToken := r.Header.Get(ClientTokenHeader) if clientToken == "" { sendError(w, http.StatusUnauthorized) return } serverClaims, err := assertServerToken(key, serverToken) if err != nil { logger.Printf("[ERROR] %+v", errors.WithStack(err)) sendError(w, http.StatusUnauthorized) return } clientClaims, err := assertClientToken( serverClaims.PeerID, store, clientToken, logger, options.IgnoredClientTokenErrors..., ) if err != nil { logger.Printf("[ERROR] %+v", errors.WithStack(err)) switch err { case peering.ErrPeerNotFound: sendError(w, http.StatusUnauthorized) case ErrNotPeered: if err := store.UpdateLastContact(serverClaims.PeerID, r.RemoteAddr, time.Now()); err != nil { logger.Printf("[ERROR] %+v", errors.WithStack(err)) sendError(w, http.StatusInternalServerError) return } sendError(w, http.StatusUnauthorized) return case ErrPeerRejected: if err := store.UpdateLastContact(serverClaims.PeerID, r.RemoteAddr, time.Now()); err != nil { logger.Printf("[ERROR] %+v", errors.WithStack(err)) sendError(w, http.StatusInternalServerError) return } sendError(w, http.StatusForbidden) return default: sendError(w, http.StatusInternalServerError) return } } match, body, err := assertBodySum(r.Body, clientClaims.BodySum) if err != nil { logger.Printf("[ERROR] %+v", errors.WithStack(err)) sendError(w, http.StatusInternalServerError) return } if !match { logger.Printf("[ERROR] %+v", errors.WithStack(ErrInvalidChecksum)) sendError(w, http.StatusBadRequest) return } if err := store.UpdateLastContact(serverClaims.PeerID, r.RemoteAddr, time.Now()); err != nil { logger.Printf("[ERROR] %+v", errors.WithStack(err)) sendError(w, http.StatusInternalServerError) return } ctx := context.WithValue(r.Context(), KeyPeerID, serverClaims.PeerID) r = r.WithContext(ctx) r.Body = ioutil.NopCloser(bytes.NewBuffer(body)) next.ServeHTTP(w, r) } return http.HandlerFunc(fn) } return middleware } func GetPeerID(r *http.Request) (peering.PeerID, error) { peerID, ok := r.Context().Value(KeyPeerID).(peering.PeerID) if !ok { return "", ErrUnauthorized } return peerID, nil } func assertServerToken(key *rsa.PublicKey, serverToken string) (*peering.ServerTokenClaims, error) { fn := func(token *jwt.Token) (interface{}, error) { return key, nil } token, err := jwt.ParseWithClaims(serverToken, &peering.ServerTokenClaims{}, fn) if err != nil { validationError, ok := err.(*jwt.ValidationError) if ok { return nil, validationError.Inner } return nil, err } if !token.Valid { return nil, ErrInvalidClaims } claims, ok := token.Claims.(*peering.ServerTokenClaims) if !ok { return nil, ErrInvalidClaims } return claims, nil } func assertClientToken(peerID peering.PeerID, store peering.Store, clientToken string, logger Logger, ignoredValidationErrors ...uint32) (*peering.ClientTokenClaims, error) { fn := func(token *jwt.Token) (interface{}, error) { peer, err := store.Get(peerID) if err != nil { return nil, err } if peer.Status == peering.StatusRejected { return nil, ErrPeerRejected } if peer.Status != peering.StatusPeered { return nil, ErrNotPeered } publicKey, err := peeringCrypto.DecodePEMToPublicKey(peer.PublicKey) if err != nil { return nil, err } return publicKey, nil } getPeeringClaims := func(token *jwt.Token) (*peering.ClientTokenClaims, error) { claims, ok := token.Claims.(*peering.ClientTokenClaims) if !ok { return nil, ErrInvalidClaims } return claims, nil } token, err := jwt.ParseWithClaims(clientToken, &peering.ClientTokenClaims{}, fn) if err != nil { validationError, ok := err.(*jwt.ValidationError) if ok { for _, c := range ignoredValidationErrors { if validationError.Errors&c != 0 { logger.Printf("ignoring token validation error: '%+v'", errors.WithStack(validationError.Inner)) return getPeeringClaims(token) } } return nil, validationError.Inner } return nil, err } if !token.Valid { return nil, ErrInvalidClaims } return getPeeringClaims(token) } func assertBodySum(rc io.ReadCloser, bodySum []byte) (bool, []byte, error) { body, err := ioutil.ReadAll(rc) if err != nil { return false, nil, err } if err := rc.Close(); err != nil { return false, nil, err } match, err := compareChecksum(body, bodySum) if err != nil { return false, nil, err } return match, body, nil } func sendError(w http.ResponseWriter, status int) { http.Error(w, http.StatusText(status), status) } func filterIgnoredValidationError(err *jwt.ValidationError, ignored ...uint32) error { for _, c := range ignored { if err.Errors&c != 0 { return nil } } return err } func compareChecksum(body []byte, sum []byte) (bool, error) { sha := sha256.New() _, err := sha.Write(body) if err != nil { return false, err } return bytes.Equal(sum, sha.Sum(nil)), nil }