Redesign authentication protocol
This commit is contained in:
@ -1,158 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"forge.cadoles.com/wpetit/go-http-peering/crypto"
|
||||
|
||||
peering "forge.cadoles.com/wpetit/go-http-peering"
|
||||
"forge.cadoles.com/wpetit/go-http-peering/memory"
|
||||
)
|
||||
|
||||
func TestAdvertiseHandlerBadRequest(t *testing.T) {
|
||||
store := memory.NewStore()
|
||||
handler := AdvertiseHandler(store)
|
||||
|
||||
req := httptest.NewRequest("POST", peering.AdvertisePath, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler(w, req)
|
||||
|
||||
res := w.Result()
|
||||
|
||||
if g, e := res.StatusCode, http.StatusBadRequest; g != e {
|
||||
t.Errorf("res.StatusCode: got '%v', expected '%v'", g, e)
|
||||
}
|
||||
|
||||
peers, err := store.List()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if g, e := len(peers), 0; g != e {
|
||||
t.Errorf("len(peers): got '%v', expected '%v'", g, e)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdvertiseHandlerInvalidPublicKeyFormat(t *testing.T) {
|
||||
store := memory.NewStore()
|
||||
handler := AdvertiseHandler(store)
|
||||
|
||||
advertising := &peering.AdvertisingRequest{
|
||||
ID: peering.NewPeerID(),
|
||||
PublicKey: []byte("Test"),
|
||||
}
|
||||
|
||||
body, err := json.Marshal(advertising)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", peering.AdvertisePath, bytes.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler(w, req)
|
||||
|
||||
res := w.Result()
|
||||
|
||||
if g, e := res.StatusCode, http.StatusBadRequest; g != e {
|
||||
t.Errorf("res.StatusCode: got '%v', expected '%v'", g, e)
|
||||
}
|
||||
|
||||
peers, err := store.List()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if g, e := len(peers), 0; g != e {
|
||||
t.Errorf("len(peers): got '%v', expected '%v'", g, e)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdvertiseHandlerExistingPeer(t *testing.T) {
|
||||
store := memory.NewStore()
|
||||
handler := AdvertiseHandler(store)
|
||||
|
||||
pk := mustGeneratePrivateKey()
|
||||
pem, err := crypto.EncodePublicKeyToPEM(pk.Public())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
peerID := peering.NewPeerID()
|
||||
|
||||
advertising := &peering.AdvertisingRequest{
|
||||
ID: peerID,
|
||||
PublicKey: pem,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(advertising)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", peering.AdvertisePath, bytes.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler(w, req)
|
||||
|
||||
req = httptest.NewRequest("POST", peering.AdvertisePath, bytes.NewReader(body))
|
||||
w = httptest.NewRecorder()
|
||||
|
||||
handler(w, req)
|
||||
|
||||
res := w.Result()
|
||||
|
||||
if g, e := res.StatusCode, http.StatusConflict; g != e {
|
||||
t.Errorf("res.StatusCode: got '%v', expected '%v'", g, e)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestAdvertiseHandlerValidRequest(t *testing.T) {
|
||||
store := memory.NewStore()
|
||||
handler := AdvertiseHandler(store)
|
||||
|
||||
pk := mustGeneratePrivateKey()
|
||||
pem, err := crypto.EncodePublicKeyToPEM(pk.Public())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
peerID := peering.NewPeerID()
|
||||
|
||||
advertising := &peering.AdvertisingRequest{
|
||||
ID: peerID,
|
||||
PublicKey: pem,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(advertising)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", peering.AdvertisePath, bytes.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler(w, req)
|
||||
|
||||
res := w.Result()
|
||||
|
||||
if g, e := res.StatusCode, http.StatusCreated; g != e {
|
||||
t.Errorf("res.StatusCode: got '%v', expected '%v'", g, e)
|
||||
}
|
||||
|
||||
peer, err := store.Get(peerID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if g, e := peer.PublicKey, advertising.PublicKey; !bytes.Equal(peer.PublicKey, advertising.PublicKey) {
|
||||
t.Errorf("peer.PublicKey: got '%v', expected '%v'", g, e)
|
||||
}
|
||||
|
||||
}
|
@ -1,13 +1,14 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
peering "forge.cadoles.com/wpetit/go-http-peering"
|
||||
"forge.cadoles.com/wpetit/go-http-peering/crypto"
|
||||
peeringCrypto "forge.cadoles.com/wpetit/go-http-peering/crypto"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -18,12 +19,26 @@ var (
|
||||
ErrUnauthorized = errors.New("unauthorized")
|
||||
)
|
||||
|
||||
func AdvertiseHandler(store peering.Store, funcs ...OptionFunc) http.HandlerFunc {
|
||||
func AdvertiseHandler(store peering.Store, key *rsa.PublicKey, funcs ...OptionFunc) http.HandlerFunc {
|
||||
|
||||
options := createOptions(funcs...)
|
||||
logger := options.Logger
|
||||
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
serverToken := r.Header.Get(ServerTokenHeader)
|
||||
if serverToken == "" {
|
||||
options.ErrorHandler(w, r, ErrInvalidAdvertisingRequest)
|
||||
return
|
||||
}
|
||||
|
||||
serverClaims, err := assertServerToken(key, serverToken)
|
||||
if err != nil {
|
||||
logger.Printf("[ERROR] %s", err)
|
||||
sendError(w, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
advertising := &peering.AdvertisingRequest{}
|
||||
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
@ -33,19 +48,13 @@ func AdvertiseHandler(store peering.Store, funcs ...OptionFunc) http.HandlerFunc
|
||||
return
|
||||
}
|
||||
|
||||
if !options.PeerIDValidator(advertising.ID) {
|
||||
logger.Printf("[ERROR] %s", ErrInvalidAdvertisingRequest)
|
||||
options.ErrorHandler(w, r, ErrInvalidAdvertisingRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := crypto.DecodePEMToPublicKey(advertising.PublicKey); err != nil {
|
||||
if _, err := peeringCrypto.DecodePEMToPublicKey(advertising.PublicKey); err != nil {
|
||||
logger.Printf("[ERROR] %s", err)
|
||||
options.ErrorHandler(w, r, ErrInvalidAdvertisingRequest)
|
||||
return
|
||||
}
|
||||
|
||||
peer, err := store.Get(advertising.ID)
|
||||
peer, err := store.Get(serverClaims.PeerID)
|
||||
|
||||
if err == nil {
|
||||
logger.Printf("[ERROR] %s", ErrPeerIDAlreadyInUse)
|
||||
@ -61,7 +70,7 @@ func AdvertiseHandler(store peering.Store, funcs ...OptionFunc) http.HandlerFunc
|
||||
|
||||
attrs := filterAttributes(options.PeerAttributes, advertising.Attributes)
|
||||
|
||||
peer, err = store.Create(advertising.ID, attrs)
|
||||
peer, err = store.Create(serverClaims.PeerID, attrs)
|
||||
if err != nil {
|
||||
logger.Printf("[ERROR] %s", err)
|
||||
options.ErrorHandler(w, r, err)
|
||||
@ -74,6 +83,12 @@ func AdvertiseHandler(store peering.Store, funcs ...OptionFunc) http.HandlerFunc
|
||||
return
|
||||
}
|
||||
|
||||
if err := store.UpdateLastContact(peer.ID, r.RemoteAddr, time.Now()); err != nil {
|
||||
logger.Printf("[ERROR] %s", err)
|
||||
options.ErrorHandler(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := store.UpdatePublicKey(peer.ID, advertising.PublicKey); err != nil {
|
||||
logger.Printf("[ERROR] %s", err)
|
||||
options.ErrorHandler(w, r, err)
|
||||
@ -212,10 +227,6 @@ func DefaultErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
func DefaultPeerIDValidator(id peering.PeerID) bool {
|
||||
return string(id) != ""
|
||||
}
|
||||
|
||||
func filterAttributes(filters []string, attrs peering.PeerAttributes) peering.PeerAttributes {
|
||||
filtered := peering.PeerAttributes{}
|
||||
for _, key := range filters {
|
||||
|
@ -3,13 +3,14 @@ package server
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"forge.cadoles.com/wpetit/go-http-peering/crypto"
|
||||
peeringCrypto "forge.cadoles.com/wpetit/go-http-peering/crypto"
|
||||
|
||||
peering "forge.cadoles.com/wpetit/go-http-peering"
|
||||
jwt "github.com/dgrijalva/jwt-go"
|
||||
@ -18,8 +19,9 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
AuthorizationType = "Bearer"
|
||||
KeyPeerID ContextKey = "peerID"
|
||||
ServerTokenHeader = "X-Server-Token" // nolint: gosec
|
||||
ClientTokenHeader = "X-Client-Token"
|
||||
KeyPeerID ContextKey = "PeerID"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -30,100 +32,63 @@ var (
|
||||
|
||||
type ContextKey string
|
||||
|
||||
func Authenticate(store peering.Store, funcs ...OptionFunc) func(http.Handler) http.Handler {
|
||||
func Authenticate(store peering.Store, key *rsa.PublicKey, funcs ...OptionFunc) func(http.Handler) http.Handler {
|
||||
options := createOptions(funcs...)
|
||||
logger := options.Logger
|
||||
|
||||
middleware := func(next http.Handler) http.Handler {
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
authorization := r.Header.Get("Authorization")
|
||||
|
||||
if authorization == "" {
|
||||
serverToken := r.Header.Get(ServerTokenHeader)
|
||||
if serverToken == "" {
|
||||
sendError(w, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.SplitN(authorization, " ", 2)
|
||||
|
||||
if len(parts) != 2 || parts[0] != AuthorizationType {
|
||||
clientToken := r.Header.Get(ClientTokenHeader)
|
||||
if clientToken == "" {
|
||||
sendError(w, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
token, err := jwt.ParseWithClaims(parts[1], &peering.PeerClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
claims, ok := token.Claims.(*peering.PeerClaims)
|
||||
if !ok {
|
||||
return nil, ErrInvalidClaims
|
||||
}
|
||||
peerID := peering.PeerID(claims.Issuer)
|
||||
peer, err := store.Get(peerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if peer.Status == peering.StatusRejected {
|
||||
return nil, ErrPeerRejected
|
||||
}
|
||||
if peer.Status != peering.StatusPeered {
|
||||
return nil, ErrNotPeered
|
||||
}
|
||||
publicKey, err := crypto.DecodePEMToPublicKey(peer.PublicKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return publicKey, nil
|
||||
})
|
||||
if err != nil || !token.Valid {
|
||||
serverClaims, err := assertServerToken(key, serverToken)
|
||||
if err != nil {
|
||||
logger.Printf("[ERROR] %s", err)
|
||||
if err == ErrPeerRejected {
|
||||
sendError(w, http.StatusForbidden)
|
||||
} else {
|
||||
sendError(w, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
clientClaims, err := assertClientToken(serverClaims.PeerID, store, clientToken)
|
||||
if err != nil {
|
||||
logger.Printf("[ERROR] %s", err)
|
||||
if err == peering.ErrPeerNotFound {
|
||||
sendError(w, http.StatusUnauthorized)
|
||||
} else {
|
||||
sendError(w, http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*peering.PeerClaims)
|
||||
if !ok {
|
||||
logger.Printf("[ERROR] %s", ErrInvalidClaims)
|
||||
sendError(w, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
match, body, err := assertBodySum(r.Body, clientClaims.BodySum)
|
||||
if err != nil {
|
||||
logger.Printf("[ERROR] %s", err)
|
||||
sendError(w, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err := r.Body.Close(); err != nil {
|
||||
logger.Printf("[ERROR] %s", err)
|
||||
sendError(w, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
match, err := compareChecksum(body, claims.BodySum)
|
||||
if err != nil {
|
||||
logger.Printf("[ERROR] %s", err)
|
||||
sendError(w, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if !match {
|
||||
logger.Printf("[ERROR] %s", ErrInvalidChecksum)
|
||||
sendError(w, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
peerID := peering.PeerID(claims.Issuer)
|
||||
|
||||
if err := store.UpdateLastContact(peerID, r.RemoteAddr, time.Now()); err != nil {
|
||||
if err := store.UpdateLastContact(serverClaims.PeerID, r.RemoteAddr, time.Now()); err != nil {
|
||||
logger.Printf("[ERROR] %s", err)
|
||||
sendError(w, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), KeyPeerID, peerID)
|
||||
ctx := context.WithValue(r.Context(), KeyPeerID, serverClaims.PeerID)
|
||||
r = r.WithContext(ctx)
|
||||
r.Body = ioutil.NopCloser(bytes.NewBuffer(body))
|
||||
|
||||
@ -143,6 +108,71 @@ func GetPeerID(r *http.Request) (peering.PeerID, error) {
|
||||
return peerID, nil
|
||||
}
|
||||
|
||||
func assertServerToken(key *rsa.PublicKey, serverToken string) (*peering.ServerTokenClaims, error) {
|
||||
fn := func(token *jwt.Token) (interface{}, error) {
|
||||
return key, nil
|
||||
}
|
||||
token, err := jwt.ParseWithClaims(serverToken, &peering.ServerTokenClaims{}, fn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !token.Valid {
|
||||
return nil, ErrInvalidClaims
|
||||
}
|
||||
claims, ok := token.Claims.(*peering.ServerTokenClaims)
|
||||
if !ok {
|
||||
return nil, ErrInvalidClaims
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func assertClientToken(peerID peering.PeerID, store peering.Store, clientToken string) (*peering.ClientTokenClaims, error) {
|
||||
fn := func(token *jwt.Token) (interface{}, error) {
|
||||
peer, err := store.Get(peerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if peer.Status == peering.StatusRejected {
|
||||
return nil, ErrPeerRejected
|
||||
}
|
||||
if peer.Status != peering.StatusPeered {
|
||||
return nil, ErrNotPeered
|
||||
}
|
||||
publicKey, err := peeringCrypto.DecodePEMToPublicKey(peer.PublicKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return publicKey, nil
|
||||
}
|
||||
token, err := jwt.ParseWithClaims(clientToken, &peering.ClientTokenClaims{}, fn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !token.Valid {
|
||||
return nil, ErrInvalidClaims
|
||||
}
|
||||
claims, ok := token.Claims.(*peering.ClientTokenClaims)
|
||||
if !ok {
|
||||
return nil, ErrInvalidClaims
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func assertBodySum(rc io.ReadCloser, bodySum []byte) (bool, []byte, error) {
|
||||
body, err := ioutil.ReadAll(rc)
|
||||
if err != nil {
|
||||
return false, nil, err
|
||||
}
|
||||
if err := rc.Close(); err != nil {
|
||||
return false, nil, err
|
||||
}
|
||||
match, err := compareChecksum(body, bodySum)
|
||||
if err != nil {
|
||||
return false, nil, err
|
||||
}
|
||||
return match, body, nil
|
||||
}
|
||||
|
||||
func sendError(w http.ResponseWriter, status int) {
|
||||
http.Error(w, http.StatusText(status), status)
|
||||
}
|
||||
|
@ -4,8 +4,6 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
peering "forge.cadoles.com/wpetit/go-http-peering"
|
||||
)
|
||||
|
||||
type Logger interface {
|
||||
@ -13,10 +11,9 @@ type Logger interface {
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
PeerAttributes []string
|
||||
ErrorHandler ErrorHandler
|
||||
PeerIDValidator func(peering.PeerID) bool
|
||||
Logger Logger
|
||||
PeerAttributes []string
|
||||
ErrorHandler ErrorHandler
|
||||
Logger Logger
|
||||
}
|
||||
|
||||
type OptionFunc func(*Options)
|
||||
@ -44,10 +41,9 @@ func WithErrorHandler(handler ErrorHandler) OptionFunc {
|
||||
func defaultOptions() *Options {
|
||||
logger := log.New(os.Stdout, "[go-http-peering] ", log.LstdFlags|log.Lshortfile)
|
||||
return &Options{
|
||||
PeerAttributes: []string{"Label"},
|
||||
ErrorHandler: DefaultErrorHandler,
|
||||
PeerIDValidator: DefaultPeerIDValidator,
|
||||
Logger: logger,
|
||||
PeerAttributes: []string{"Label"},
|
||||
ErrorHandler: DefaultErrorHandler,
|
||||
Logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user