From 5f3f5083296e4ddcc8b26e886dba510849a9d8da Mon Sep 17 00:00:00 2001 From: William Petit Date: Fri, 17 May 2019 10:51:15 +0200 Subject: [PATCH] Allow ignoring of specific client token errors --- server/middleware.go | 45 ++++++++++++++++++++++++++++++++++++-------- server/option.go | 20 ++++++++++++++------ 2 files changed, 51 insertions(+), 14 deletions(-) diff --git a/server/middleware.go b/server/middleware.go index f6ebf23..46277f9 100644 --- a/server/middleware.go +++ b/server/middleware.go @@ -58,7 +58,13 @@ func Authenticate(store peering.Store, key *rsa.PublicKey, funcs ...OptionFunc) return } - clientClaims, err := assertClientToken(serverClaims.PeerID, store, clientToken) + clientClaims, err := assertClientToken( + serverClaims.PeerID, + store, + clientToken, + logger, + options.IgnoredClientTokenErrors..., + ) if err != nil { logger.Printf("[ERROR] %s", err) switch err { @@ -78,10 +84,11 @@ func Authenticate(store peering.Store, key *rsa.PublicKey, funcs ...OptionFunc) return } sendError(w, http.StatusForbidden) + return default: sendError(w, http.StatusInternalServerError) + return } - return } match, body, err := assertBodySum(r.Body, clientClaims.BodySum) @@ -145,7 +152,7 @@ func assertServerToken(key *rsa.PublicKey, serverToken string) (*peering.ServerT return claims, nil } -func assertClientToken(peerID peering.PeerID, store peering.Store, clientToken string) (*peering.ClientTokenClaims, error) { +func assertClientToken(peerID peering.PeerID, store peering.Store, clientToken string, logger Logger, ignoredValidationErrors ...uint32) (*peering.ClientTokenClaims, error) { fn := func(token *jwt.Token) (interface{}, error) { peer, err := store.Get(peerID) if err != nil { @@ -163,22 +170,35 @@ func assertClientToken(peerID peering.PeerID, store peering.Store, clientToken s } return publicKey, nil } + + getPeeringClaims := func(token *jwt.Token) (*peering.ClientTokenClaims, error) { + claims, ok := token.Claims.(*peering.ClientTokenClaims) + if !ok { + return nil, ErrInvalidClaims + } + return claims, nil + } + token, err := jwt.ParseWithClaims(clientToken, &peering.ClientTokenClaims{}, fn) if err != nil { validationError, ok := err.(*jwt.ValidationError) if ok { + for _, c := range ignoredValidationErrors { + if validationError.Errors&c != 0 { + logger.Printf("ignoring token validation error: '%s'", validationError.Inner) + return getPeeringClaims(token) + } + } return nil, validationError.Inner } return nil, err } + if !token.Valid { return nil, ErrInvalidClaims } - claims, ok := token.Claims.(*peering.ClientTokenClaims) - if !ok { - return nil, ErrInvalidClaims - } - return claims, nil + + return getPeeringClaims(token) } func assertBodySum(rc io.ReadCloser, bodySum []byte) (bool, []byte, error) { @@ -200,6 +220,15 @@ func sendError(w http.ResponseWriter, status int) { http.Error(w, http.StatusText(status), status) } +func filterIgnoredValidationError(err *jwt.ValidationError, ignored ...uint32) error { + for _, c := range ignored { + if err.Errors&c != 0 { + return nil + } + } + return err +} + func compareChecksum(body []byte, sum []byte) (bool, error) { sha := sha256.New() _, err := sha.Write(body) diff --git a/server/option.go b/server/option.go index 0fa3934..bf74921 100644 --- a/server/option.go +++ b/server/option.go @@ -11,9 +11,10 @@ type Logger interface { } type Options struct { - PeerAttributes []string - ErrorHandler ErrorHandler - Logger Logger + PeerAttributes []string + ErrorHandler ErrorHandler + Logger Logger + IgnoredClientTokenErrors []uint32 } type OptionFunc func(*Options) @@ -38,12 +39,19 @@ func WithErrorHandler(handler ErrorHandler) OptionFunc { } } +func WithIgnoredClientTokenErrors(codes ...uint32) OptionFunc { + return func(options *Options) { + options.IgnoredClientTokenErrors = codes + } +} + func defaultOptions() *Options { logger := log.New(os.Stdout, "[go-http-peering] ", log.LstdFlags|log.Lshortfile) return &Options{ - PeerAttributes: []string{"Label"}, - ErrorHandler: DefaultErrorHandler, - Logger: logger, + PeerAttributes: []string{"Label"}, + ErrorHandler: DefaultErrorHandler, + Logger: logger, + IgnoredClientTokenErrors: make([]uint32, 0), } }