commit 2390ed2d15c55bb140b8c55a6f42fcc0023b09eb Author: William Petit Date: Fri Oct 16 17:27:44 2020 +0200 initial commit diff --git a/README.md b/README.md new file mode 100644 index 0000000..91dc9cf --- /dev/null +++ b/README.md @@ -0,0 +1,7 @@ +# go-captiveportal + +Librairie utilitaire permettant d'implémenter des fonctionnalités de portail captif en Go. + +## Licence + +AGPL-3.0 \ No newline at end of file diff --git a/error.go b/error.go new file mode 100644 index 0000000..081c952 --- /dev/null +++ b/error.go @@ -0,0 +1,7 @@ +package captiveportal + +import "errors" + +var ( + ErrClientIdentificationFailed = errors.New("client identification failed") +) diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..7f0c366 --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module forge.cadoles.com/wpetit/go-captiveportal + +go 1.15 + +require github.com/pkg/errors v0.9.1 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..7c401c3 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/identifier.go b/identifier.go new file mode 100644 index 0000000..16ce9cf --- /dev/null +++ b/identifier.go @@ -0,0 +1,17 @@ +package captiveportal + +import "net/http" + +type Identifier interface { + Identify(r *http.Request) (string, error) +} + +type IdentifierFunc func(r *http.Request) (string, error) + +func (f IdentifierFunc) Identify(r *http.Request) (string, error) { + return f(r) +} + +func DefaultIdentifier(r *http.Request) (string, error) { + return "", nil +} diff --git a/liar.go b/liar.go new file mode 100644 index 0000000..c75ec88 --- /dev/null +++ b/liar.go @@ -0,0 +1,49 @@ +package captiveportal + +import ( + "net/http" +) + +type Liar interface { + Handle(os OS, w http.ResponseWriter, r *http.Request) +} + +type LiarFunc func(os OS, w http.ResponseWriter, r *http.Request) + +func (f LiarFunc) Handle(os OS, w http.ResponseWriter, r *http.Request) { + f(os, w, r) +} + +func HandleAndroid(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) +} + +func HandleApple(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + // nolint: errcheck + w.Write([]byte("SuccessSuccess")) +} + +func HandleWindows(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + // nolint: errcheck + w.Write([]byte("Microsoft NCSI")) +} + +// nolint: gochecknoglobals +var defaultLiars = map[OS]http.HandlerFunc{ + OSAndroid: http.HandlerFunc(HandleAndroid), + OSApple: http.HandlerFunc(HandleApple), + OSWindows: http.HandlerFunc(HandleWindows), +} + +func DefaultLiar(os OS, w http.ResponseWriter, r *http.Request) { + liar, exists := defaultLiars[os] + if !exists { + w.WriteHeader(http.StatusNoContent) + + return + } + + liar.ServeHTTP(w, r) +} diff --git a/matcher.go b/matcher.go new file mode 100644 index 0000000..90b2aea --- /dev/null +++ b/matcher.go @@ -0,0 +1,106 @@ +package captiveportal + +import ( + "net/http" + "strings" + + "github.com/pkg/errors" +) + +type OS int + +const ( + OSUnknown OS = iota + OSAndroid + OSApple + OSWindows + OSLinux +) + +type Matcher interface { + Match(r *http.Request) (bool, OS, error) +} + +type MatchFunc func(r *http.Request) (bool, OS, error) + +func (f MatchFunc) Match(r *http.Request) (bool, OS, error) { + return f(r) +} + +func MatchAndroid(r *http.Request) (bool, error) { + // Samples + // + // https://www.google.com/generate_204 + // http://connectivitycheck.gstatic.com/generate_204 + // http://www.google.com/gen_204 + // http://play.googleapis.com/generate_204 + // http://connectivitycheck.gstatic.com/generate_204 + // http://clients3.google.com/generate_204 + // http://g.cn + matches := strings.Contains(r.URL.Path, "generate_204") || + strings.Contains(r.URL.Path, "gen_204") || + r.URL.Hostname() == "g.cn" + + return matches, nil +} + +func MatchApple(r *http.Request) (bool, error) { + // Samples + // + // http://www.apple.com/library/test/success.html + // http://captive.apple.com/hotspot-detect.html + hostname := r.URL.Hostname() + + matches := hostname == "www.apple.com" || + hostname == "captive.apple.com" + + return matches, nil +} + +func MatchLinux(r *http.Request) (bool, error) { + // Samples + // + // http://start.ubuntu.com/connectivity-check.html + // http://nmcheck.gnome.org/check_network_status.txt + hostname := r.URL.Hostname() + + matches := hostname == "start.ubuntu.com" || + hostname == "nmcheck.gnome.org" + + return matches, nil +} + +func MatchWindows(r *http.Request) (bool, error) { + // Samples + // + // http://www.msftncsi.com + // http://www.msftncsi.com/ncsi.txt + hostname := r.URL.Hostname() + + matches := hostname == "www.msftncsi.com" + + return matches, nil +} + +// nolint: gochecknoglobals +var defaultMatchers = map[OS]func(r *http.Request) (bool, error){ + OSAndroid: MatchAndroid, + OSApple: MatchApple, + OSWindows: MatchWindows, + OSLinux: MatchLinux, +} + +func DefaultMatch(r *http.Request) (bool, OS, error) { + for os, match := range defaultMatchers { + matches, err := match(r) + if err != nil { + return false, OSUnknown, errors.WithStack(err) + } + + if matches { + return matches, os, nil + } + } + + return false, OSUnknown, nil +} diff --git a/middleware.go b/middleware.go new file mode 100644 index 0000000..6dfab48 --- /dev/null +++ b/middleware.go @@ -0,0 +1,52 @@ +package captiveportal + +import ( + "net/http" + + "github.com/pkg/errors" +) + +func (s *Service) Middleware() func(next http.Handler) http.Handler { + registry := s.registry + matcher := s.options.Matcher + liar := s.options.Liar + + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + matches, os, err := matcher.Match(r) + if err != nil { + panic(errors.WithStack(err)) + } + + if !matches { + next.ServeHTTP(w, r) + + return + } + + id, err := s.options.Identifier.Identify(r) + if err != nil { + panic(errors.Wrap(err, ErrClientIdentificationFailed.Error())) + } + + registry.Touch(id, os) + + if registry.IsLying(id) { + liar.Handle(os, w, r) + + return + } + + if registry.IsCaptive(id) { + // Redirect to configured URL + http.Redirect(w, r, s.captivePortalURL, http.StatusTemporaryRedirect) + + return + } + + next.ServeHTTP(w, r) + } + + return http.HandlerFunc(fn) + } +} diff --git a/option.go b/option.go new file mode 100644 index 0000000..67e9842 --- /dev/null +++ b/option.go @@ -0,0 +1,17 @@ +package captiveportal + +type Options struct { + Identifier Identifier + Matcher Matcher + Liar Liar +} + +type OptionsFunc func(*Options) + +func DefaultOptions() *Options { + return &Options{ + Identifier: IdentifierFunc(DefaultIdentifier), + Matcher: MatchFunc(DefaultMatch), + Liar: LiarFunc(DefaultLiar), + } +} diff --git a/registry.go b/registry.go new file mode 100644 index 0000000..3e4af3e --- /dev/null +++ b/registry.go @@ -0,0 +1,100 @@ +package captiveportal + +import ( + "sync" + "time" +) + +type Client struct { + LastSeen time.Time + Captive bool + Lying bool + OS OS +} + +type Registry struct { + mutex sync.RWMutex + clients map[string]*Client +} + +func (r *Registry) Touch(id string, os OS) { + r.mutex.Lock() + defer r.mutex.Unlock() + + client := r.upsert(id) + client.LastSeen = time.Now() + client.OS = os +} + +func (r *Registry) Lie(id string) { + r.mutex.Lock() + defer r.mutex.Unlock() + + client := r.upsert(id) + client.Lying = true +} + +func (r *Registry) Release(id string) { + r.mutex.Lock() + defer r.mutex.Unlock() + + client := r.upsert(id) + client.Captive = false +} + +func (r *Registry) IsCaptive(id string) bool { + r.mutex.RLock() + defer r.mutex.RUnlock() + + client, exists := r.clients[id] + if !exists { + return false + } + + return client.Captive +} + +func (r *Registry) ClientOS(id string) OS { + r.mutex.RLock() + defer r.mutex.RUnlock() + + client, exists := r.clients[id] + if !exists { + return OSUnknown + } + + return client.OS +} + +func (r *Registry) IsLying(id string) bool { + r.mutex.RLock() + defer r.mutex.RUnlock() + + client, exists := r.clients[id] + if !exists { + return false + } + + return client.Captive +} + +func (r *Registry) upsert(id string) *Client { + client, exists := r.clients[id] + if !exists { + client = &Client{ + Captive: true, + Lying: false, + OS: OSUnknown, + } + + r.clients[id] = client + } + + return client +} + +func NewRegistry() *Registry { + return &Registry{ + clients: make(map[string]*Client), + } +} diff --git a/service.go b/service.go new file mode 100644 index 0000000..67d9e98 --- /dev/null +++ b/service.go @@ -0,0 +1,76 @@ +package captiveportal + +import ( + "net/http" + + "github.com/pkg/errors" +) + +type Service struct { + captivePortalURL string + options *Options + registry *Registry +} + +func (s *Service) ClientID(r *http.Request) (string, error) { + id, err := s.options.Identifier.Identify(r) + if err != nil { + return "", errors.Wrap(err, ErrClientIdentificationFailed.Error()) + } + + return id, nil +} + +func (s *Service) IsCaptive(r *http.Request) (bool, error) { + id, err := s.ClientID(r) + if err != nil { + return false, errors.WithStack(err) + } + + return s.registry.IsCaptive(id), nil +} + +func (s *Service) Release(r *http.Request) error { + id, err := s.ClientID(r) + if err != nil { + return errors.WithStack(err) + } + + s.registry.Release(id) + + return nil +} + +func (s *Service) Lie(r *http.Request) error { + id, err := s.ClientID(r) + if err != nil { + return errors.WithStack(err) + } + + s.registry.Lie(id) + + return nil +} + +func (s *Service) ClientOS(r *http.Request) (OS, error) { + id, err := s.ClientID(r) + if err != nil { + return OSUnknown, errors.WithStack(err) + } + + return s.registry.ClientOS(id), nil +} + +func New(captivePortalURL string, opts ...OptionsFunc) *Service { + options := DefaultOptions() + + for _, o := range opts { + o(options) + } + + return &Service{ + captivePortalURL: captivePortalURL, + options: options, + registry: NewRegistry(), + } +}