Initial commit
This commit is contained in:
32
middleware/debug.go
Normal file
32
middleware/debug.go
Normal file
@ -0,0 +1,32 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/go-chi/chi/middleware"
|
||||
)
|
||||
|
||||
const (
|
||||
// KeyDebug is the context key associated with the debug value
|
||||
KeyDebug ContextKey = "debug"
|
||||
)
|
||||
|
||||
// ErrInvalidDebug is returned when no debug value
|
||||
// could be found on the given context
|
||||
var ErrInvalidDebug = errors.New("invalid debug")
|
||||
|
||||
// GetDebug retrieves the debug value from the given context
|
||||
func GetDebug(ctx context.Context) (bool, error) {
|
||||
debug, ok := ctx.Value(KeyDebug).(bool)
|
||||
if !ok {
|
||||
return false, ErrInvalidDebug
|
||||
}
|
||||
return debug, nil
|
||||
}
|
||||
|
||||
// Debug expose the given debug flag as a context value
|
||||
// on the HTTP requests
|
||||
func Debug(debug bool) Middleware {
|
||||
return middleware.WithValue(KeyDebug, debug)
|
||||
}
|
22
middleware/debug_test.go
Normal file
22
middleware/debug_test.go
Normal file
@ -0,0 +1,22 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestContextDebug(t *testing.T) {
|
||||
|
||||
debug := false
|
||||
ctx := context.WithValue(context.Background(), KeyDebug, debug)
|
||||
|
||||
dbg, err := GetDebug(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if dbg {
|
||||
t.Fatal("debug should be false")
|
||||
}
|
||||
|
||||
}
|
28
middleware/invalid_host.go
Normal file
28
middleware/invalid_host.go
Normal file
@ -0,0 +1,28 @@
|
||||
package middleware
|
||||
|
||||
import "net/http"
|
||||
|
||||
// InvalidHostRedirect returns a middleware that redirects incoming
|
||||
// requests that do not use the expected host/schemes
|
||||
func InvalidHostRedirect(expectedHost string, useTLS bool) Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
invalidScheme := (useTLS && r.TLS == nil) || (!useTLS && r.TLS != nil)
|
||||
invalidHost := expectedHost != r.Host
|
||||
if invalidHost || invalidScheme {
|
||||
if expectedHost != "" {
|
||||
r.URL.Host = expectedHost
|
||||
}
|
||||
if useTLS && r.TLS == nil {
|
||||
r.URL.Scheme = "https"
|
||||
} else if !useTLS && r.TLS != nil {
|
||||
r.URL.Scheme = "http"
|
||||
}
|
||||
http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
return http.HandlerFunc(fn)
|
||||
}
|
||||
}
|
33
middleware/service_container.go
Normal file
33
middleware/service_container.go
Normal file
@ -0,0 +1,33 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"forge.cadoles.com/wpetit/goweb/service"
|
||||
"github.com/go-chi/chi/middleware"
|
||||
)
|
||||
|
||||
const (
|
||||
// KeyServiceContainer is the context key associated with the ServiceContainer value
|
||||
KeyServiceContainer ContextKey = "serviceContainer"
|
||||
)
|
||||
|
||||
// ErrInvalidServiceContainer is returned when no service container
|
||||
// could be found on the given context
|
||||
var ErrInvalidServiceContainer = errors.New("invalid service container")
|
||||
|
||||
// GetServiceContainer retrieves the service container from the given context
|
||||
func GetServiceContainer(ctx context.Context) (*service.Container, error) {
|
||||
container, ok := ctx.Value(KeyServiceContainer).(*service.Container)
|
||||
if !ok {
|
||||
return nil, ErrInvalidServiceContainer
|
||||
}
|
||||
return container, nil
|
||||
}
|
||||
|
||||
// ServiceContainer expose the given service container as a context value
|
||||
// on the HTTP requests
|
||||
func ServiceContainer(container *service.Container) Middleware {
|
||||
return middleware.WithValue(KeyServiceContainer, container)
|
||||
}
|
41
middleware/service_container_test.go
Normal file
41
middleware/service_container_test.go
Normal file
@ -0,0 +1,41 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"forge.cadoles.com/wpetit/goweb/service"
|
||||
)
|
||||
|
||||
func TestContextServiceContainer(t *testing.T) {
|
||||
|
||||
container := service.NewContainer()
|
||||
ctx := context.WithValue(context.Background(), KeyServiceContainer, container)
|
||||
|
||||
ctn, err := GetServiceContainer(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if ctn == nil {
|
||||
t.Fatal("container should not be nil")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestContextInvalidServiceContainer(t *testing.T) {
|
||||
|
||||
invalidContainer := struct{}{}
|
||||
ctx := context.WithValue(context.Background(), KeyServiceContainer, invalidContainer)
|
||||
|
||||
container, err := GetServiceContainer(ctx)
|
||||
|
||||
if g, e := err, ErrInvalidServiceContainer; g != e {
|
||||
t.Errorf("err: got '%v', expected '%v'", g, e)
|
||||
}
|
||||
|
||||
if container != nil {
|
||||
t.Errorf("container: got '%v', expected '%v'", container, nil)
|
||||
}
|
||||
|
||||
}
|
9
middleware/type.go
Normal file
9
middleware/type.go
Normal file
@ -0,0 +1,9 @@
|
||||
package middleware
|
||||
|
||||
import "net/http"
|
||||
|
||||
// ContextKey are values exposed on the request context
|
||||
type ContextKey string
|
||||
|
||||
// Middleware An HTTP middleware
|
||||
type Middleware func(http.Handler) http.Handler
|
Reference in New Issue
Block a user