From 2390ed2d15c55bb140b8c55a6f42fcc0023b09eb Mon Sep 17 00:00:00 2001 From: William Petit Date: Fri, 16 Oct 2020 17:27:44 +0200 Subject: [PATCH] initial commit --- README.md | 7 ++++ error.go | 7 ++++ go.mod | 5 +++ go.sum | 2 + identifier.go | 17 ++++++++ liar.go | 49 +++++++++++++++++++++++ matcher.go | 106 ++++++++++++++++++++++++++++++++++++++++++++++++++ middleware.go | 52 +++++++++++++++++++++++++ option.go | 17 ++++++++ registry.go | 100 +++++++++++++++++++++++++++++++++++++++++++++++ service.go | 76 ++++++++++++++++++++++++++++++++++++ 11 files changed, 438 insertions(+) create mode 100644 README.md create mode 100644 error.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 identifier.go create mode 100644 liar.go create mode 100644 matcher.go create mode 100644 middleware.go create mode 100644 option.go create mode 100644 registry.go create mode 100644 service.go diff --git a/README.md b/README.md new file mode 100644 index 0000000..91dc9cf --- /dev/null +++ b/README.md @@ -0,0 +1,7 @@ +# go-captiveportal + +Librairie utilitaire permettant d'implémenter des fonctionnalités de portail captif en Go. + +## Licence + +AGPL-3.0 \ No newline at end of file diff --git a/error.go b/error.go new file mode 100644 index 0000000..081c952 --- /dev/null +++ b/error.go @@ -0,0 +1,7 @@ +package captiveportal + +import "errors" + +var ( + ErrClientIdentificationFailed = errors.New("client identification failed") +) diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..7f0c366 --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module forge.cadoles.com/wpetit/go-captiveportal + +go 1.15 + +require github.com/pkg/errors v0.9.1 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..7c401c3 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/identifier.go b/identifier.go new file mode 100644 index 0000000..16ce9cf --- /dev/null +++ b/identifier.go @@ -0,0 +1,17 @@ +package captiveportal + +import "net/http" + +type Identifier interface { + Identify(r *http.Request) (string, error) +} + +type IdentifierFunc func(r *http.Request) (string, error) + +func (f IdentifierFunc) Identify(r *http.Request) (string, error) { + return f(r) +} + +func DefaultIdentifier(r *http.Request) (string, error) { + return "", nil +} diff --git a/liar.go b/liar.go new file mode 100644 index 0000000..c75ec88 --- /dev/null +++ b/liar.go @@ -0,0 +1,49 @@ +package captiveportal + +import ( + "net/http" +) + +type Liar interface { + Handle(os OS, w http.ResponseWriter, r *http.Request) +} + +type LiarFunc func(os OS, w http.ResponseWriter, r *http.Request) + +func (f LiarFunc) Handle(os OS, w http.ResponseWriter, r *http.Request) { + f(os, w, r) +} + +func HandleAndroid(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) +} + +func HandleApple(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + // nolint: errcheck + w.Write([]byte("SuccessSuccess")) +} + +func HandleWindows(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + // nolint: errcheck + w.Write([]byte("Microsoft NCSI")) +} + +// nolint: gochecknoglobals +var defaultLiars = map[OS]http.HandlerFunc{ + OSAndroid: http.HandlerFunc(HandleAndroid), + OSApple: http.HandlerFunc(HandleApple), + OSWindows: http.HandlerFunc(HandleWindows), +} + +func DefaultLiar(os OS, w http.ResponseWriter, r *http.Request) { + liar, exists := defaultLiars[os] + if !exists { + w.WriteHeader(http.StatusNoContent) + + return + } + + liar.ServeHTTP(w, r) +} diff --git a/matcher.go b/matcher.go new file mode 100644 index 0000000..90b2aea --- /dev/null +++ b/matcher.go @@ -0,0 +1,106 @@ +package captiveportal + +import ( + "net/http" + "strings" + + "github.com/pkg/errors" +) + +type OS int + +const ( + OSUnknown OS = iota + OSAndroid + OSApple + OSWindows + OSLinux +) + +type Matcher interface { + Match(r *http.Request) (bool, OS, error) +} + +type MatchFunc func(r *http.Request) (bool, OS, error) + +func (f MatchFunc) Match(r *http.Request) (bool, OS, error) { + return f(r) +} + +func MatchAndroid(r *http.Request) (bool, error) { + // Samples + // + // https://www.google.com/generate_204 + // http://connectivitycheck.gstatic.com/generate_204 + // http://www.google.com/gen_204 + // http://play.googleapis.com/generate_204 + // http://connectivitycheck.gstatic.com/generate_204 + // http://clients3.google.com/generate_204 + // http://g.cn + matches := strings.Contains(r.URL.Path, "generate_204") || + strings.Contains(r.URL.Path, "gen_204") || + r.URL.Hostname() == "g.cn" + + return matches, nil +} + +func MatchApple(r *http.Request) (bool, error) { + // Samples + // + // http://www.apple.com/library/test/success.html + // http://captive.apple.com/hotspot-detect.html + hostname := r.URL.Hostname() + + matches := hostname == "www.apple.com" || + hostname == "captive.apple.com" + + return matches, nil +} + +func MatchLinux(r *http.Request) (bool, error) { + // Samples + // + // http://start.ubuntu.com/connectivity-check.html + // http://nmcheck.gnome.org/check_network_status.txt + hostname := r.URL.Hostname() + + matches := hostname == "start.ubuntu.com" || + hostname == "nmcheck.gnome.org" + + return matches, nil +} + +func MatchWindows(r *http.Request) (bool, error) { + // Samples + // + // http://www.msftncsi.com + // http://www.msftncsi.com/ncsi.txt + hostname := r.URL.Hostname() + + matches := hostname == "www.msftncsi.com" + + return matches, nil +} + +// nolint: gochecknoglobals +var defaultMatchers = map[OS]func(r *http.Request) (bool, error){ + OSAndroid: MatchAndroid, + OSApple: MatchApple, + OSWindows: MatchWindows, + OSLinux: MatchLinux, +} + +func DefaultMatch(r *http.Request) (bool, OS, error) { + for os, match := range defaultMatchers { + matches, err := match(r) + if err != nil { + return false, OSUnknown, errors.WithStack(err) + } + + if matches { + return matches, os, nil + } + } + + return false, OSUnknown, nil +} diff --git a/middleware.go b/middleware.go new file mode 100644 index 0000000..6dfab48 --- /dev/null +++ b/middleware.go @@ -0,0 +1,52 @@ +package captiveportal + +import ( + "net/http" + + "github.com/pkg/errors" +) + +func (s *Service) Middleware() func(next http.Handler) http.Handler { + registry := s.registry + matcher := s.options.Matcher + liar := s.options.Liar + + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + matches, os, err := matcher.Match(r) + if err != nil { + panic(errors.WithStack(err)) + } + + if !matches { + next.ServeHTTP(w, r) + + return + } + + id, err := s.options.Identifier.Identify(r) + if err != nil { + panic(errors.Wrap(err, ErrClientIdentificationFailed.Error())) + } + + registry.Touch(id, os) + + if registry.IsLying(id) { + liar.Handle(os, w, r) + + return + } + + if registry.IsCaptive(id) { + // Redirect to configured URL + http.Redirect(w, r, s.captivePortalURL, http.StatusTemporaryRedirect) + + return + } + + next.ServeHTTP(w, r) + } + + return http.HandlerFunc(fn) + } +} diff --git a/option.go b/option.go new file mode 100644 index 0000000..67e9842 --- /dev/null +++ b/option.go @@ -0,0 +1,17 @@ +package captiveportal + +type Options struct { + Identifier Identifier + Matcher Matcher + Liar Liar +} + +type OptionsFunc func(*Options) + +func DefaultOptions() *Options { + return &Options{ + Identifier: IdentifierFunc(DefaultIdentifier), + Matcher: MatchFunc(DefaultMatch), + Liar: LiarFunc(DefaultLiar), + } +} diff --git a/registry.go b/registry.go new file mode 100644 index 0000000..3e4af3e --- /dev/null +++ b/registry.go @@ -0,0 +1,100 @@ +package captiveportal + +import ( + "sync" + "time" +) + +type Client struct { + LastSeen time.Time + Captive bool + Lying bool + OS OS +} + +type Registry struct { + mutex sync.RWMutex + clients map[string]*Client +} + +func (r *Registry) Touch(id string, os OS) { + r.mutex.Lock() + defer r.mutex.Unlock() + + client := r.upsert(id) + client.LastSeen = time.Now() + client.OS = os +} + +func (r *Registry) Lie(id string) { + r.mutex.Lock() + defer r.mutex.Unlock() + + client := r.upsert(id) + client.Lying = true +} + +func (r *Registry) Release(id string) { + r.mutex.Lock() + defer r.mutex.Unlock() + + client := r.upsert(id) + client.Captive = false +} + +func (r *Registry) IsCaptive(id string) bool { + r.mutex.RLock() + defer r.mutex.RUnlock() + + client, exists := r.clients[id] + if !exists { + return false + } + + return client.Captive +} + +func (r *Registry) ClientOS(id string) OS { + r.mutex.RLock() + defer r.mutex.RUnlock() + + client, exists := r.clients[id] + if !exists { + return OSUnknown + } + + return client.OS +} + +func (r *Registry) IsLying(id string) bool { + r.mutex.RLock() + defer r.mutex.RUnlock() + + client, exists := r.clients[id] + if !exists { + return false + } + + return client.Captive +} + +func (r *Registry) upsert(id string) *Client { + client, exists := r.clients[id] + if !exists { + client = &Client{ + Captive: true, + Lying: false, + OS: OSUnknown, + } + + r.clients[id] = client + } + + return client +} + +func NewRegistry() *Registry { + return &Registry{ + clients: make(map[string]*Client), + } +} diff --git a/service.go b/service.go new file mode 100644 index 0000000..67d9e98 --- /dev/null +++ b/service.go @@ -0,0 +1,76 @@ +package captiveportal + +import ( + "net/http" + + "github.com/pkg/errors" +) + +type Service struct { + captivePortalURL string + options *Options + registry *Registry +} + +func (s *Service) ClientID(r *http.Request) (string, error) { + id, err := s.options.Identifier.Identify(r) + if err != nil { + return "", errors.Wrap(err, ErrClientIdentificationFailed.Error()) + } + + return id, nil +} + +func (s *Service) IsCaptive(r *http.Request) (bool, error) { + id, err := s.ClientID(r) + if err != nil { + return false, errors.WithStack(err) + } + + return s.registry.IsCaptive(id), nil +} + +func (s *Service) Release(r *http.Request) error { + id, err := s.ClientID(r) + if err != nil { + return errors.WithStack(err) + } + + s.registry.Release(id) + + return nil +} + +func (s *Service) Lie(r *http.Request) error { + id, err := s.ClientID(r) + if err != nil { + return errors.WithStack(err) + } + + s.registry.Lie(id) + + return nil +} + +func (s *Service) ClientOS(r *http.Request) (OS, error) { + id, err := s.ClientID(r) + if err != nil { + return OSUnknown, errors.WithStack(err) + } + + return s.registry.ClientOS(id), nil +} + +func New(captivePortalURL string, opts ...OptionsFunc) *Service { + options := DefaultOptions() + + for _, o := range opts { + o(options) + } + + return &Service{ + captivePortalURL: captivePortalURL, + options: options, + registry: NewRegistry(), + } +}