package server import ( "bytes" "context" "crypto/rsa" "crypto/sha256" "errors" "io" "io/ioutil" "time" peeringCrypto "forge.cadoles.com/wpetit/go-http-peering/crypto" peering "forge.cadoles.com/wpetit/go-http-peering" jwt "github.com/dgrijalva/jwt-go" "net/http" ) const ( ServerTokenHeader = "X-Server-Token" // nolint: gosec ClientTokenHeader = "X-Client-Token" KeyPeerID ContextKey = "PeerID" ) var ( ErrInvalidClaims = errors.New("invalid claims") ErrInvalidChecksum = errors.New("invalid checksum") ErrNotPeered = errors.New("not peered") ) type ContextKey string func Authenticate(store peering.Store, key *rsa.PublicKey, funcs ...OptionFunc) func(http.Handler) http.Handler { options := createOptions(funcs...) logger := options.Logger middleware := func(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { serverToken := r.Header.Get(ServerTokenHeader) if serverToken == "" { sendError(w, http.StatusUnauthorized) return } clientToken := r.Header.Get(ClientTokenHeader) if clientToken == "" { sendError(w, http.StatusUnauthorized) return } serverClaims, err := assertServerToken(key, serverToken) if err != nil { logger.Printf("[ERROR] %s", err) sendError(w, http.StatusUnauthorized) return } clientClaims, err := assertClientToken(serverClaims.PeerID, store, clientToken) if err != nil { logger.Printf("[ERROR] %s", err) if err == peering.ErrPeerNotFound { sendError(w, http.StatusUnauthorized) } else { sendError(w, http.StatusInternalServerError) } return } match, body, err := assertBodySum(r.Body, clientClaims.BodySum) if err != nil { logger.Printf("[ERROR] %s", err) sendError(w, http.StatusInternalServerError) return } if !match { logger.Printf("[ERROR] %s", ErrInvalidChecksum) sendError(w, http.StatusBadRequest) return } if err := store.UpdateLastContact(serverClaims.PeerID, r.RemoteAddr, time.Now()); err != nil { logger.Printf("[ERROR] %s", err) sendError(w, http.StatusInternalServerError) return } ctx := context.WithValue(r.Context(), KeyPeerID, serverClaims.PeerID) r = r.WithContext(ctx) r.Body = ioutil.NopCloser(bytes.NewBuffer(body)) next.ServeHTTP(w, r) } return http.HandlerFunc(fn) } return middleware } func GetPeerID(r *http.Request) (peering.PeerID, error) { peerID, ok := r.Context().Value(KeyPeerID).(peering.PeerID) if !ok { return "", ErrUnauthorized } return peerID, nil } func assertServerToken(key *rsa.PublicKey, serverToken string) (*peering.ServerTokenClaims, error) { fn := func(token *jwt.Token) (interface{}, error) { return key, nil } token, err := jwt.ParseWithClaims(serverToken, &peering.ServerTokenClaims{}, fn) if err != nil { return nil, err } if !token.Valid { return nil, ErrInvalidClaims } claims, ok := token.Claims.(*peering.ServerTokenClaims) if !ok { return nil, ErrInvalidClaims } return claims, nil } func assertClientToken(peerID peering.PeerID, store peering.Store, clientToken string) (*peering.ClientTokenClaims, error) { fn := func(token *jwt.Token) (interface{}, error) { peer, err := store.Get(peerID) if err != nil { return nil, err } if peer.Status == peering.StatusRejected { return nil, ErrPeerRejected } if peer.Status != peering.StatusPeered { return nil, ErrNotPeered } publicKey, err := peeringCrypto.DecodePEMToPublicKey(peer.PublicKey) if err != nil { return nil, err } return publicKey, nil } token, err := jwt.ParseWithClaims(clientToken, &peering.ClientTokenClaims{}, fn) if err != nil { return nil, err } if !token.Valid { return nil, ErrInvalidClaims } claims, ok := token.Claims.(*peering.ClientTokenClaims) if !ok { return nil, ErrInvalidClaims } return claims, nil } func assertBodySum(rc io.ReadCloser, bodySum []byte) (bool, []byte, error) { body, err := ioutil.ReadAll(rc) if err != nil { return false, nil, err } if err := rc.Close(); err != nil { return false, nil, err } match, err := compareChecksum(body, bodySum) if err != nil { return false, nil, err } return match, body, nil } func sendError(w http.ResponseWriter, status int) { http.Error(w, http.StatusText(status), status) } func compareChecksum(body []byte, sum []byte) (bool, error) { sha := sha256.New() _, err := sha.Write(body) if err != nil { return false, err } return bytes.Equal(sum, sha.Sum(nil)), nil }