Initial commit
This commit is contained in:
158
server/advertise_test.go
Normal file
158
server/advertise_test.go
Normal file
@ -0,0 +1,158 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"forge.cadoles.com/wpetit/go-http-peering/crypto"
|
||||
|
||||
peering "forge.cadoles.com/wpetit/go-http-peering"
|
||||
"forge.cadoles.com/wpetit/go-http-peering/memory"
|
||||
)
|
||||
|
||||
func TestAdvertiseHandlerBadRequest(t *testing.T) {
|
||||
store := memory.NewStore()
|
||||
handler := AdvertiseHandler(store)
|
||||
|
||||
req := httptest.NewRequest("POST", peering.AdvertisePath, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler(w, req)
|
||||
|
||||
res := w.Result()
|
||||
|
||||
if g, e := res.StatusCode, http.StatusBadRequest; g != e {
|
||||
t.Errorf("res.StatusCode: got '%v', expected '%v'", g, e)
|
||||
}
|
||||
|
||||
peers, err := store.List()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if g, e := len(peers), 0; g != e {
|
||||
t.Errorf("len(peers): got '%v', expected '%v'", g, e)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdvertiseHandlerInvalidPublicKeyFormat(t *testing.T) {
|
||||
store := memory.NewStore()
|
||||
handler := AdvertiseHandler(store)
|
||||
|
||||
advertising := &peering.AdvertisingRequest{
|
||||
ID: peering.NewPeerID(),
|
||||
PublicKey: []byte("Test"),
|
||||
}
|
||||
|
||||
body, err := json.Marshal(advertising)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", peering.AdvertisePath, bytes.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler(w, req)
|
||||
|
||||
res := w.Result()
|
||||
|
||||
if g, e := res.StatusCode, http.StatusBadRequest; g != e {
|
||||
t.Errorf("res.StatusCode: got '%v', expected '%v'", g, e)
|
||||
}
|
||||
|
||||
peers, err := store.List()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if g, e := len(peers), 0; g != e {
|
||||
t.Errorf("len(peers): got '%v', expected '%v'", g, e)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdvertiseHandlerExistingPeer(t *testing.T) {
|
||||
store := memory.NewStore()
|
||||
handler := AdvertiseHandler(store)
|
||||
|
||||
pk := mustGeneratePrivateKey()
|
||||
pem, err := crypto.EncodePublicKeyToPEM(pk.Public())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
peerID := peering.NewPeerID()
|
||||
|
||||
advertising := &peering.AdvertisingRequest{
|
||||
ID: peerID,
|
||||
PublicKey: pem,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(advertising)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", peering.AdvertisePath, bytes.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler(w, req)
|
||||
|
||||
req = httptest.NewRequest("POST", peering.AdvertisePath, bytes.NewReader(body))
|
||||
w = httptest.NewRecorder()
|
||||
|
||||
handler(w, req)
|
||||
|
||||
res := w.Result()
|
||||
|
||||
if g, e := res.StatusCode, http.StatusConflict; g != e {
|
||||
t.Errorf("res.StatusCode: got '%v', expected '%v'", g, e)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestAdvertiseHandlerValidRequest(t *testing.T) {
|
||||
store := memory.NewStore()
|
||||
handler := AdvertiseHandler(store)
|
||||
|
||||
pk := mustGeneratePrivateKey()
|
||||
pem, err := crypto.EncodePublicKeyToPEM(pk.Public())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
peerID := peering.NewPeerID()
|
||||
|
||||
advertising := &peering.AdvertisingRequest{
|
||||
ID: peerID,
|
||||
PublicKey: pem,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(advertising)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", peering.AdvertisePath, bytes.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler(w, req)
|
||||
|
||||
res := w.Result()
|
||||
|
||||
if g, e := res.StatusCode, http.StatusCreated; g != e {
|
||||
t.Errorf("res.StatusCode: got '%v', expected '%v'", g, e)
|
||||
}
|
||||
|
||||
peer, err := store.Get(peerID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if g, e := peer.PublicKey, advertising.PublicKey; !bytes.Equal(peer.PublicKey, advertising.PublicKey) {
|
||||
t.Errorf("peer.PublicKey: got '%v', expected '%v'", g, e)
|
||||
}
|
||||
|
||||
}
|
227
server/handler.go
Normal file
227
server/handler.go
Normal file
@ -0,0 +1,227 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
peering "forge.cadoles.com/wpetit/go-http-peering"
|
||||
"forge.cadoles.com/wpetit/go-http-peering/crypto"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidAdvertisingRequest = errors.New("invalid advertising request")
|
||||
ErrInvalidUpdateRequest = errors.New("invalid update request")
|
||||
ErrPeerRejected = errors.New("peer rejected")
|
||||
ErrPeerIDAlreadyInUse = errors.New("peer id already in use")
|
||||
ErrUnauthorized = errors.New("unauthorized")
|
||||
)
|
||||
|
||||
func AdvertiseHandler(store peering.Store, funcs ...OptionFunc) http.HandlerFunc {
|
||||
|
||||
options := createOptions(funcs...)
|
||||
logger := options.Logger
|
||||
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
advertising := &peering.AdvertisingRequest{}
|
||||
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
if err := decoder.Decode(advertising); err != nil {
|
||||
logger.Printf("[ERROR] %s", err)
|
||||
options.ErrorHandler(w, r, ErrInvalidAdvertisingRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if !options.PeerIDValidator(advertising.ID) {
|
||||
logger.Printf("[ERROR] %s", ErrInvalidAdvertisingRequest)
|
||||
options.ErrorHandler(w, r, ErrInvalidAdvertisingRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := crypto.DecodePEMToPublicKey(advertising.PublicKey); err != nil {
|
||||
logger.Printf("[ERROR] %s", err)
|
||||
options.ErrorHandler(w, r, ErrInvalidAdvertisingRequest)
|
||||
return
|
||||
}
|
||||
|
||||
peer, err := store.Get(advertising.ID)
|
||||
|
||||
if err == nil {
|
||||
logger.Printf("[ERROR] %s", ErrPeerIDAlreadyInUse)
|
||||
options.ErrorHandler(w, r, ErrPeerIDAlreadyInUse)
|
||||
return
|
||||
}
|
||||
|
||||
if err != peering.ErrPeerNotFound {
|
||||
logger.Printf("[ERROR] %s", err)
|
||||
options.ErrorHandler(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
attrs := filterAttributes(options.PeerAttributes, advertising.Attributes)
|
||||
|
||||
peer, err = store.Create(advertising.ID, attrs)
|
||||
if err != nil {
|
||||
logger.Printf("[ERROR] %s", err)
|
||||
options.ErrorHandler(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := store.UpdateLastContact(peer.ID, r.RemoteAddr, time.Now()); err != nil {
|
||||
logger.Printf("[ERROR] %s", err)
|
||||
options.ErrorHandler(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := store.UpdatePublicKey(peer.ID, advertising.PublicKey); err != nil {
|
||||
logger.Printf("[ERROR] %s", err)
|
||||
options.ErrorHandler(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
|
||||
}
|
||||
|
||||
return handler
|
||||
}
|
||||
|
||||
func UpdateHandler(store peering.Store, funcs ...OptionFunc) http.HandlerFunc {
|
||||
options := createOptions(funcs...)
|
||||
logger := options.Logger
|
||||
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
update := &peering.UpdateRequest{}
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
if err := decoder.Decode(update); err != nil {
|
||||
options.ErrorHandler(w, r, ErrInvalidUpdateRequest)
|
||||
return
|
||||
}
|
||||
|
||||
peerID, err := GetPeerID(r)
|
||||
if err != nil {
|
||||
logger.Printf("[ERROR] %s", err)
|
||||
options.ErrorHandler(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
peer, err := store.Get(peerID)
|
||||
if err != nil {
|
||||
logger.Printf("[ERROR] %s", err)
|
||||
options.ErrorHandler(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
if peer == nil {
|
||||
logger.Printf("[ERROR] %s", ErrUnauthorized)
|
||||
options.ErrorHandler(w, r, ErrUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if peer.Status == peering.StatusRejected {
|
||||
logger.Printf("[ERROR] %s", ErrPeerRejected)
|
||||
options.ErrorHandler(w, r, ErrPeerRejected)
|
||||
return
|
||||
}
|
||||
|
||||
if err := store.UpdateLastContact(peer.ID, r.RemoteAddr, time.Now()); err != nil {
|
||||
logger.Printf("[ERROR] %s", err)
|
||||
options.ErrorHandler(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
attrs := filterAttributes(options.PeerAttributes, update.Attributes)
|
||||
if err := store.UpdateAttributes(peer.ID, attrs); err != nil {
|
||||
logger.Printf("[ERROR] %s", err)
|
||||
options.ErrorHandler(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
|
||||
}
|
||||
|
||||
return handler
|
||||
}
|
||||
|
||||
func PingHandler(store peering.Store, funcs ...OptionFunc) http.HandlerFunc {
|
||||
options := createOptions(funcs...)
|
||||
logger := options.Logger
|
||||
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
update := &peering.UpdateRequest{}
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
if err := decoder.Decode(update); err != nil {
|
||||
options.ErrorHandler(w, r, ErrInvalidUpdateRequest)
|
||||
return
|
||||
}
|
||||
|
||||
peerID, err := GetPeerID(r)
|
||||
if err != nil {
|
||||
logger.Printf("[ERROR] %s", err)
|
||||
options.ErrorHandler(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
peer, err := store.Get(peerID)
|
||||
if err != nil {
|
||||
logger.Printf("[ERROR] %s", err)
|
||||
options.ErrorHandler(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
if peer == nil {
|
||||
logger.Printf("[ERROR] %s", ErrUnauthorized)
|
||||
options.ErrorHandler(w, r, ErrUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if peer.Status == peering.StatusRejected {
|
||||
logger.Printf("[ERROR] %s", ErrPeerRejected)
|
||||
options.ErrorHandler(w, r, ErrPeerRejected)
|
||||
return
|
||||
}
|
||||
|
||||
if err := store.UpdateLastContact(peer.ID, r.RemoteAddr, time.Now()); err != nil {
|
||||
logger.Printf("[ERROR] %s", err)
|
||||
options.ErrorHandler(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
return handler
|
||||
}
|
||||
|
||||
func DefaultErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
|
||||
switch err {
|
||||
case ErrInvalidAdvertisingRequest:
|
||||
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
|
||||
case ErrPeerIDAlreadyInUse:
|
||||
http.Error(w, http.StatusText(http.StatusConflict), http.StatusConflict)
|
||||
case ErrUnauthorized:
|
||||
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||
case ErrPeerRejected:
|
||||
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
||||
default:
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func DefaultPeerIDValidator(id peering.PeerID) bool {
|
||||
return string(id) != ""
|
||||
}
|
||||
|
||||
func filterAttributes(filters []string, attrs peering.PeerAttributes) peering.PeerAttributes {
|
||||
filtered := peering.PeerAttributes{}
|
||||
for _, key := range filters {
|
||||
if _, exists := attrs[key]; exists {
|
||||
filtered[key] = attrs[key]
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
157
server/middleware.go
Normal file
157
server/middleware.go
Normal file
@ -0,0 +1,157 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"forge.cadoles.com/wpetit/go-http-peering/crypto"
|
||||
|
||||
peering "forge.cadoles.com/wpetit/go-http-peering"
|
||||
jwt "github.com/dgrijalva/jwt-go"
|
||||
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
AuthorizationType = "Bearer"
|
||||
KeyPeerID ContextKey = "peerID"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidClaims = errors.New("invalid claims")
|
||||
ErrInvalidChecksum = errors.New("invalid checksum")
|
||||
ErrNotPeered = errors.New("not peered")
|
||||
)
|
||||
|
||||
type ContextKey string
|
||||
|
||||
func Authenticate(store peering.Store, funcs ...OptionFunc) func(http.Handler) http.Handler {
|
||||
options := createOptions(funcs...)
|
||||
logger := options.Logger
|
||||
|
||||
middleware := func(next http.Handler) http.Handler {
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
authorization := r.Header.Get("Authorization")
|
||||
|
||||
if authorization == "" {
|
||||
sendError(w, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.SplitN(authorization, " ", 2)
|
||||
|
||||
if len(parts) != 2 || parts[0] != AuthorizationType {
|
||||
sendError(w, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
token, err := jwt.ParseWithClaims(parts[1], &peering.PeerClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
claims, ok := token.Claims.(*peering.PeerClaims)
|
||||
if !ok {
|
||||
return nil, ErrInvalidClaims
|
||||
}
|
||||
peerID := peering.PeerID(claims.Issuer)
|
||||
peer, err := store.Get(peerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if peer.Status == peering.StatusRejected {
|
||||
return nil, ErrPeerRejected
|
||||
}
|
||||
if peer.Status != peering.StatusPeered {
|
||||
return nil, ErrNotPeered
|
||||
}
|
||||
publicKey, err := crypto.DecodePEMToPublicKey(peer.PublicKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return publicKey, nil
|
||||
})
|
||||
if err != nil || !token.Valid {
|
||||
logger.Printf("[ERROR] %s", err)
|
||||
if err == ErrPeerRejected {
|
||||
sendError(w, http.StatusForbidden)
|
||||
} else {
|
||||
sendError(w, http.StatusUnauthorized)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*peering.PeerClaims)
|
||||
if !ok {
|
||||
logger.Printf("[ERROR] %s", ErrInvalidClaims)
|
||||
sendError(w, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
logger.Printf("[ERROR] %s", err)
|
||||
sendError(w, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err := r.Body.Close(); err != nil {
|
||||
logger.Printf("[ERROR] %s", err)
|
||||
sendError(w, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
match, err := compareChecksum(body, claims.BodySum)
|
||||
if err != nil {
|
||||
logger.Printf("[ERROR] %s", err)
|
||||
sendError(w, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if !match {
|
||||
logger.Printf("[ERROR] %s", ErrInvalidChecksum)
|
||||
sendError(w, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
peerID := peering.PeerID(claims.Issuer)
|
||||
|
||||
if err := store.UpdateLastContact(peerID, r.RemoteAddr, time.Now()); err != nil {
|
||||
logger.Printf("[ERROR] %s", err)
|
||||
sendError(w, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), KeyPeerID, peerID)
|
||||
r = r.WithContext(ctx)
|
||||
r.Body = ioutil.NopCloser(bytes.NewBuffer(body))
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
|
||||
}
|
||||
return http.HandlerFunc(fn)
|
||||
}
|
||||
return middleware
|
||||
}
|
||||
|
||||
func GetPeerID(r *http.Request) (peering.PeerID, error) {
|
||||
peerID, ok := r.Context().Value(KeyPeerID).(peering.PeerID)
|
||||
if !ok {
|
||||
return "", ErrUnauthorized
|
||||
}
|
||||
return peerID, nil
|
||||
}
|
||||
|
||||
func sendError(w http.ResponseWriter, status int) {
|
||||
http.Error(w, http.StatusText(status), status)
|
||||
}
|
||||
|
||||
func compareChecksum(body []byte, sum []byte) (bool, error) {
|
||||
sha := sha256.New()
|
||||
_, err := sha.Write(body)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return bytes.Equal(sum, sha.Sum(nil)), nil
|
||||
}
|
60
server/option.go
Normal file
60
server/option.go
Normal file
@ -0,0 +1,60 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
peering "forge.cadoles.com/wpetit/go-http-peering"
|
||||
)
|
||||
|
||||
type Logger interface {
|
||||
Printf(string, ...interface{})
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
PeerAttributes []string
|
||||
ErrorHandler ErrorHandler
|
||||
PeerIDValidator func(peering.PeerID) bool
|
||||
Logger Logger
|
||||
}
|
||||
|
||||
type OptionFunc func(*Options)
|
||||
|
||||
type ErrorHandler func(http.ResponseWriter, *http.Request, error)
|
||||
|
||||
func WithPeerAttributes(attrs ...string) OptionFunc {
|
||||
return func(options *Options) {
|
||||
options.PeerAttributes = attrs
|
||||
}
|
||||
}
|
||||
|
||||
func WithLogger(logger Logger) OptionFunc {
|
||||
return func(options *Options) {
|
||||
options.Logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
func WithErrorHandler(handler ErrorHandler) OptionFunc {
|
||||
return func(options *Options) {
|
||||
options.ErrorHandler = handler
|
||||
}
|
||||
}
|
||||
|
||||
func defaultOptions() *Options {
|
||||
logger := log.New(os.Stdout, "[go-http-peering] ", log.LstdFlags|log.Lshortfile)
|
||||
return &Options{
|
||||
PeerAttributes: []string{"Label"},
|
||||
ErrorHandler: DefaultErrorHandler,
|
||||
PeerIDValidator: DefaultPeerIDValidator,
|
||||
Logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func createOptions(funcs ...OptionFunc) *Options {
|
||||
options := defaultOptions()
|
||||
for _, fn := range funcs {
|
||||
fn(options)
|
||||
}
|
||||
return options
|
||||
}
|
14
server/util_test.go
Normal file
14
server/util_test.go
Normal file
@ -0,0 +1,14 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
)
|
||||
|
||||
func mustGeneratePrivateKey() *rsa.PrivateKey {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return privateKey
|
||||
}
|
Reference in New Issue
Block a user