go-http-peering/server/middleware.go

204 lines
4.9 KiB
Go

package server
import (
"bytes"
"context"
"crypto/rsa"
"crypto/sha256"
"errors"
"io"
"io/ioutil"
"time"
peeringCrypto "forge.cadoles.com/wpetit/go-http-peering/crypto"
peering "forge.cadoles.com/wpetit/go-http-peering"
jwt "github.com/dgrijalva/jwt-go"
"net/http"
)
const (
// ServerTokenHeader is the name of the request header from which Authenticate
// reads the server token.
ServerTokenHeader = "X-Server-Token" // nolint: gosec
// ClientTokenHeader is the name of the request header from which Authenticate
// reads the client token.
ClientTokenHeader = "X-Client-Token"
// KeyPeerID is the request context key under which Authenticate stores the
// peer ID of an authenticated request; GetPeerID reads it back.
KeyPeerID ContextKey = "PeerID"
)
var (
// ErrInvalidClaims is returned by the token assertions when a parsed token is
// not valid or its claims do not have the expected type.
ErrInvalidClaims = errors.New("invalid claims")
// ErrInvalidChecksum is logged by Authenticate when the request body does not
// match the checksum carried by the client token claims.
ErrInvalidChecksum = errors.New("invalid checksum")
// ErrNotPeered is returned while verifying a client token when the peer's
// status is neither rejected nor peered.
ErrNotPeered = errors.New("not peered")
)
// ContextKey is the string type of the keys this package uses to store values
// in a request context (see KeyPeerID).
type ContextKey string
// Authenticate builds an HTTP middleware that only lets a request reach the
// wrapped handler once its peering credentials have been verified.
//
// For every request the middleware:
//
//   - requires both the ServerTokenHeader and ClientTokenHeader headers, and
//     answers 401 when either is empty;
//   - verifies the server token with key (401 on failure);
//   - verifies the client token for the peer named by the server token. A peer
//     that is not found or not peered yet gets 401 (for a not-peered peer the
//     last contact is recorded in store first); any other failure gets 500;
//   - reads the whole body and checks it against the checksum carried by the
//     client token claims (400 on mismatch, 500 when the body cannot be read);
//   - records the last contact of the peer in store (500 on failure).
//
// On success the peer ID is stored in the request context under KeyPeerID, the
// body is replaced by an in-memory copy of the bytes already read, and the
// next handler is called. Errors are reported through the logger taken from
// the options built out of funcs.
func Authenticate(store peering.Store, key *rsa.PublicKey, funcs ...OptionFunc) func(http.Handler) http.Handler {
	logger := createOptions(funcs...).Logger

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Both tokens are mandatory; a missing one is rejected the same way.
			serverToken := r.Header.Get(ServerTokenHeader)
			clientToken := r.Header.Get(ClientTokenHeader)
			if serverToken == "" || clientToken == "" {
				sendError(w, http.StatusUnauthorized)
				return
			}

			serverClaims, err := assertServerToken(key, serverToken)
			if err != nil {
				logger.Printf("[ERROR] %s", err)
				sendError(w, http.StatusUnauthorized)
				return
			}
			peerID := serverClaims.PeerID

			clientClaims, err := assertClientToken(peerID, store, clientToken)
			if err != nil {
				logger.Printf("[ERROR] %s", err)

				// Anything that is not a known peering state is a server-side failure.
				status := http.StatusInternalServerError
				switch err {
				case peering.ErrPeerNotFound:
					status = http.StatusUnauthorized
				case ErrNotPeered:
					// The peer is known but not accepted yet: still record the contact.
					status = http.StatusUnauthorized
					if touchErr := store.UpdateLastContact(peerID, r.RemoteAddr, time.Now()); touchErr != nil {
						logger.Printf("[ERROR] %s", touchErr)
						status = http.StatusInternalServerError
					}
				}
				sendError(w, status)
				return
			}

			sumOK, payload, err := assertBodySum(r.Body, clientClaims.BodySum)
			switch {
			case err != nil:
				logger.Printf("[ERROR] %s", err)
				sendError(w, http.StatusInternalServerError)
				return
			case !sumOK:
				logger.Printf("[ERROR] %s", ErrInvalidChecksum)
				sendError(w, http.StatusBadRequest)
				return
			}

			if touchErr := store.UpdateLastContact(peerID, r.RemoteAddr, time.Now()); touchErr != nil {
				logger.Printf("[ERROR] %s", touchErr)
				sendError(w, http.StatusInternalServerError)
				return
			}

			// The original body has been consumed by the checksum verification, so
			// hand the next handler a fresh reader over the same bytes.
			authenticated := r.WithContext(context.WithValue(r.Context(), KeyPeerID, peerID))
			authenticated.Body = ioutil.NopCloser(bytes.NewBuffer(payload))
			next.ServeHTTP(w, authenticated)
		})
	}
}
// GetPeerID returns the peer ID stored in the context of r under KeyPeerID.
// It returns ErrUnauthorized when the context holds no peering.PeerID value
// under that key.
func GetPeerID(r *http.Request) (peering.PeerID, error) {
	if id, found := r.Context().Value(KeyPeerID).(peering.PeerID); found {
		return id, nil
	}

	return "", ErrUnauthorized
}
// assertServerToken parses serverToken as a JWT carrying
// peering.ServerTokenClaims and verifies it with key.
//
// It returns the decoded claims on success. When parsing fails with a
// jwt.ValidationError, the error it wraps is returned if there is one;
// otherwise the validation error itself is returned. ErrInvalidClaims is
// returned when the token is reported as not valid or its claims are not of
// the expected type. The returned claims are nil if and only if the returned
// error is non-nil.
func assertServerToken(key *rsa.PublicKey, serverToken string) (*peering.ServerTokenClaims, error) {
	keyFunc := func(token *jwt.Token) (interface{}, error) {
		return key, nil
	}

	token, err := jwt.ParseWithClaims(serverToken, &peering.ServerTokenClaims{}, keyFunc)
	if err != nil {
		// Some validation errors (for instance a token that does not have three
		// segments) carry only a message and leave Inner nil. Returning Inner
		// blindly would report success with nil claims and make the caller
		// dereference a nil pointer, so fall back to the error itself.
		if validationError, ok := err.(*jwt.ValidationError); ok && validationError.Inner != nil {
			return nil, validationError.Inner
		}
		return nil, err
	}

	if !token.Valid {
		return nil, ErrInvalidClaims
	}

	claims, ok := token.Claims.(*peering.ServerTokenClaims)
	if !ok {
		return nil, ErrInvalidClaims
	}

	return claims, nil
}
// assertClientToken parses clientToken as a JWT carrying
// peering.ClientTokenClaims and verifies it with the public key registered in
// store for peerID.
//
// The verification key is resolved lazily: the peer is loaded from store, a
// rejected peer yields ErrPeerRejected, a peer whose status is not peered
// yields ErrNotPeered, and otherwise its PEM-encoded public key is decoded.
//
// When parsing fails with a jwt.ValidationError, the error it wraps is
// returned if there is one, so that callers can compare it against the
// sentinel errors above; otherwise the validation error itself is returned.
// ErrInvalidClaims is returned when the token is reported as not valid or its
// claims are not of the expected type. The returned claims are nil if and only
// if the returned error is non-nil.
func assertClientToken(peerID peering.PeerID, store peering.Store, clientToken string) (*peering.ClientTokenClaims, error) {
	keyFunc := func(token *jwt.Token) (interface{}, error) {
		peer, err := store.Get(peerID)
		if err != nil {
			return nil, err
		}

		if peer.Status == peering.StatusRejected {
			return nil, ErrPeerRejected
		}
		if peer.Status != peering.StatusPeered {
			return nil, ErrNotPeered
		}

		publicKey, err := peeringCrypto.DecodePEMToPublicKey(peer.PublicKey)
		if err != nil {
			return nil, err
		}

		return publicKey, nil
	}

	token, err := jwt.ParseWithClaims(clientToken, &peering.ClientTokenClaims{}, keyFunc)
	if err != nil {
		// Some validation errors (for instance a token that does not have three
		// segments) carry only a message and leave Inner nil. Returning Inner
		// blindly would report success with nil claims and make the caller
		// dereference a nil pointer, so fall back to the error itself.
		if validationError, ok := err.(*jwt.ValidationError); ok && validationError.Inner != nil {
			return nil, validationError.Inner
		}
		return nil, err
	}

	if !token.Valid {
		return nil, ErrInvalidClaims
	}

	claims, ok := token.Claims.(*peering.ClientTokenClaims)
	if !ok {
		return nil, ErrInvalidClaims
	}

	return claims, nil
}
// assertBodySum reads rc until EOF, closes it, and checks the bytes read
// against bodySum using compareChecksum.
//
// It returns whether the checksum matches together with the bytes read, so
// that the caller can rebuild a body from them. The bytes are returned even
// when the checksum does not match. When reading, closing or comparing fails,
// it returns false, a nil slice and the error.
func assertBodySum(rc io.ReadCloser, bodySum []byte) (bool, []byte, error) {
	body, err := ioutil.ReadAll(rc)
	if err != nil {
		// Release the body even though reading failed. The read error is the
		// one worth reporting, so a close error is deliberately dropped here.
		_ = rc.Close()
		return false, nil, err
	}

	if err := rc.Close(); err != nil {
		return false, nil, err
	}

	match, err := compareChecksum(body, bodySum)
	if err != nil {
		return false, nil, err
	}

	return match, body, nil
}
// sendError replies on w with the given HTTP status code, using the standard
// status text for that code as the response message.
func sendError(w http.ResponseWriter, status int) {
	message := http.StatusText(status)
	http.Error(w, message, status)
}
// compareChecksum reports whether sum is the SHA-256 digest of body.
// The returned error is always nil; it is kept for the callers' signature.
func compareChecksum(body []byte, sum []byte) (bool, error) {
	digest := sha256.Sum256(body)

	return bytes.Equal(sum, digest[:]), nil
}