240 lines
5.8 KiB
Go
240 lines
5.8 KiB
Go
package server
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/rsa"
|
|
"crypto/sha256"
|
|
"errors"
|
|
"io"
|
|
"io/ioutil"
|
|
"time"
|
|
|
|
peeringCrypto "forge.cadoles.com/Cadoles/go-http-peering/crypto"
|
|
|
|
peering "forge.cadoles.com/Cadoles/go-http-peering"
|
|
jwt "github.com/dgrijalva/jwt-go"
|
|
|
|
"net/http"
|
|
)
|
|
|
|
// ContextKey is the type of the keys this package uses to store values in a
// request's context.
type ContextKey string

const (
	// ServerTokenHeader names the request header holding the token that is
	// verified against the RSA public key given to Authenticate.
	ServerTokenHeader = "X-Server-Token" // nolint: gosec

	// ClientTokenHeader names the request header holding the token that is
	// verified against the peer's registered public key.
	ClientTokenHeader = "X-Client-Token"

	// KeyPeerID is the context key under which Authenticate stores the
	// authenticated peer's ID (read it back with GetPeerID).
	KeyPeerID ContextKey = "PeerID"
)

var (
	// ErrInvalidClaims reports a token that is not valid or whose claims are
	// not of the expected type.
	ErrInvalidClaims = errors.New("invalid claims")

	// ErrInvalidChecksum reports a request body whose SHA-256 digest does not
	// match the checksum carried by the client token.
	ErrInvalidChecksum = errors.New("invalid checksum")

	// ErrNotPeered reports a known peer whose status is neither peered nor
	// rejected.
	ErrNotPeered = errors.New("not peered")
)
|
|
|
|
func Authenticate(store peering.Store, key *rsa.PublicKey, funcs ...OptionFunc) func(http.Handler) http.Handler {
|
|
options := createOptions(funcs...)
|
|
logger := options.Logger
|
|
|
|
middleware := func(next http.Handler) http.Handler {
|
|
fn := func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
serverToken := r.Header.Get(ServerTokenHeader)
|
|
if serverToken == "" {
|
|
sendError(w, http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
clientToken := r.Header.Get(ClientTokenHeader)
|
|
if clientToken == "" {
|
|
sendError(w, http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
serverClaims, err := assertServerToken(key, serverToken)
|
|
if err != nil {
|
|
logger.Printf("[ERROR] %s", err)
|
|
sendError(w, http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
clientClaims, err := assertClientToken(
|
|
serverClaims.PeerID,
|
|
store,
|
|
clientToken,
|
|
logger,
|
|
options.IgnoredClientTokenErrors...,
|
|
)
|
|
if err != nil {
|
|
logger.Printf("[ERROR] %s", err)
|
|
switch err {
|
|
case peering.ErrPeerNotFound:
|
|
sendError(w, http.StatusUnauthorized)
|
|
case ErrNotPeered:
|
|
if err := store.UpdateLastContact(serverClaims.PeerID, r.RemoteAddr, time.Now()); err != nil {
|
|
logger.Printf("[ERROR] %s", err)
|
|
sendError(w, http.StatusInternalServerError)
|
|
return
|
|
}
|
|
sendError(w, http.StatusUnauthorized)
|
|
case ErrPeerRejected:
|
|
if err := store.UpdateLastContact(serverClaims.PeerID, r.RemoteAddr, time.Now()); err != nil {
|
|
logger.Printf("[ERROR] %s", err)
|
|
sendError(w, http.StatusInternalServerError)
|
|
return
|
|
}
|
|
sendError(w, http.StatusForbidden)
|
|
return
|
|
default:
|
|
sendError(w, http.StatusInternalServerError)
|
|
return
|
|
}
|
|
}
|
|
|
|
match, body, err := assertBodySum(r.Body, clientClaims.BodySum)
|
|
if err != nil {
|
|
logger.Printf("[ERROR] %s", err)
|
|
sendError(w, http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if !match {
|
|
logger.Printf("[ERROR] %s", ErrInvalidChecksum)
|
|
sendError(w, http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
if err := store.UpdateLastContact(serverClaims.PeerID, r.RemoteAddr, time.Now()); err != nil {
|
|
logger.Printf("[ERROR] %s", err)
|
|
sendError(w, http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
ctx := context.WithValue(r.Context(), KeyPeerID, serverClaims.PeerID)
|
|
r = r.WithContext(ctx)
|
|
r.Body = ioutil.NopCloser(bytes.NewBuffer(body))
|
|
|
|
next.ServeHTTP(w, r)
|
|
|
|
}
|
|
return http.HandlerFunc(fn)
|
|
}
|
|
return middleware
|
|
}
|
|
|
|
func GetPeerID(r *http.Request) (peering.PeerID, error) {
|
|
peerID, ok := r.Context().Value(KeyPeerID).(peering.PeerID)
|
|
if !ok {
|
|
return "", ErrUnauthorized
|
|
}
|
|
return peerID, nil
|
|
}
|
|
|
|
func assertServerToken(key *rsa.PublicKey, serverToken string) (*peering.ServerTokenClaims, error) {
|
|
fn := func(token *jwt.Token) (interface{}, error) {
|
|
return key, nil
|
|
}
|
|
token, err := jwt.ParseWithClaims(serverToken, &peering.ServerTokenClaims{}, fn)
|
|
if err != nil {
|
|
validationError, ok := err.(*jwt.ValidationError)
|
|
if ok {
|
|
return nil, validationError.Inner
|
|
}
|
|
return nil, err
|
|
}
|
|
if !token.Valid {
|
|
return nil, ErrInvalidClaims
|
|
}
|
|
claims, ok := token.Claims.(*peering.ServerTokenClaims)
|
|
if !ok {
|
|
return nil, ErrInvalidClaims
|
|
}
|
|
return claims, nil
|
|
}
|
|
|
|
func assertClientToken(peerID peering.PeerID, store peering.Store, clientToken string, logger Logger, ignoredValidationErrors ...uint32) (*peering.ClientTokenClaims, error) {
|
|
fn := func(token *jwt.Token) (interface{}, error) {
|
|
peer, err := store.Get(peerID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if peer.Status == peering.StatusRejected {
|
|
return nil, ErrPeerRejected
|
|
}
|
|
if peer.Status != peering.StatusPeered {
|
|
return nil, ErrNotPeered
|
|
}
|
|
publicKey, err := peeringCrypto.DecodePEMToPublicKey(peer.PublicKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return publicKey, nil
|
|
}
|
|
|
|
getPeeringClaims := func(token *jwt.Token) (*peering.ClientTokenClaims, error) {
|
|
claims, ok := token.Claims.(*peering.ClientTokenClaims)
|
|
if !ok {
|
|
return nil, ErrInvalidClaims
|
|
}
|
|
return claims, nil
|
|
}
|
|
|
|
token, err := jwt.ParseWithClaims(clientToken, &peering.ClientTokenClaims{}, fn)
|
|
if err != nil {
|
|
validationError, ok := err.(*jwt.ValidationError)
|
|
if ok {
|
|
for _, c := range ignoredValidationErrors {
|
|
if validationError.Errors&c != 0 {
|
|
logger.Printf("ignoring token validation error: '%s'", validationError.Inner)
|
|
return getPeeringClaims(token)
|
|
}
|
|
}
|
|
return nil, validationError.Inner
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
if !token.Valid {
|
|
return nil, ErrInvalidClaims
|
|
}
|
|
|
|
return getPeeringClaims(token)
|
|
}
|
|
|
|
func assertBodySum(rc io.ReadCloser, bodySum []byte) (bool, []byte, error) {
|
|
body, err := ioutil.ReadAll(rc)
|
|
if err != nil {
|
|
return false, nil, err
|
|
}
|
|
if err := rc.Close(); err != nil {
|
|
return false, nil, err
|
|
}
|
|
match, err := compareChecksum(body, bodySum)
|
|
if err != nil {
|
|
return false, nil, err
|
|
}
|
|
return match, body, nil
|
|
}
|
|
|
|
// sendError answers the request with the given HTTP status and a plain-text
// body holding that status's standard text (for example "Unauthorized").
func sendError(w http.ResponseWriter, status int) {
	message := http.StatusText(status)
	http.Error(w, message, status)
}
|
|
|
|
func filterIgnoredValidationError(err *jwt.ValidationError, ignored ...uint32) error {
|
|
for _, c := range ignored {
|
|
if err.Errors&c != 0 {
|
|
return nil
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
|
|
// compareChecksum reports whether sum is exactly the SHA-256 digest of body.
// The error result is always nil and is kept for the callers' signature.
func compareChecksum(body []byte, sum []byte) (bool, error) {
	digest := sha256.Sum256(body)
	return bytes.Equal(sum, digest[:]), nil
}
|