initial commit

This commit is contained in:
wpetit 2020-10-16 17:27:44 +02:00
commit 2390ed2d15
11 changed files with 438 additions and 0 deletions

7
README.md Normal file
View File

@ -0,0 +1,7 @@
# go-captiveportal
Librairie utilitaire permettant d'implémenter des fonctionnalités de portail captif en Go.
## Licence
AGPL-3.0

7
error.go Normal file
View File

@ -0,0 +1,7 @@
package captiveportal
import "errors"
var (
ErrClientIdentificationFailed = errors.New("client identification failed")
)

5
go.mod Normal file
View File

@ -0,0 +1,5 @@
module forge.cadoles.com/wpetit/go-captiveportal
go 1.15
require github.com/pkg/errors v0.9.1

2
go.sum Normal file
View File

@ -0,0 +1,2 @@
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=

17
identifier.go Normal file
View File

@ -0,0 +1,17 @@
package captiveportal
import "net/http"
type Identifier interface {
Identify(r *http.Request) (string, error)
}
type IdentifierFunc func(r *http.Request) (string, error)
func (f IdentifierFunc) Identify(r *http.Request) (string, error) {
return f(r)
}
func DefaultIdentifier(r *http.Request) (string, error) {
return "", nil
}

49
liar.go Normal file
View File

@ -0,0 +1,49 @@
package captiveportal
import (
"net/http"
)
type Liar interface {
Handle(os OS, w http.ResponseWriter, r *http.Request)
}
type LiarFunc func(os OS, w http.ResponseWriter, r *http.Request)
func (f LiarFunc) Handle(os OS, w http.ResponseWriter, r *http.Request) {
f(os, w, r)
}
func HandleAndroid(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}
func HandleApple(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
// nolint: errcheck
w.Write([]byte("<HTML><HEAD><TITLE>Success</TITLE></HEAD><BODY>Success</BODY></HTML>"))
}
func HandleWindows(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
// nolint: errcheck
w.Write([]byte("Microsoft NCSI"))
}
// nolint: gochecknoglobals
var defaultLiars = map[OS]http.HandlerFunc{
OSAndroid: http.HandlerFunc(HandleAndroid),
OSApple: http.HandlerFunc(HandleApple),
OSWindows: http.HandlerFunc(HandleWindows),
}
func DefaultLiar(os OS, w http.ResponseWriter, r *http.Request) {
liar, exists := defaultLiars[os]
if !exists {
w.WriteHeader(http.StatusNoContent)
return
}
liar.ServeHTTP(w, r)
}

106
matcher.go Normal file
View File

@ -0,0 +1,106 @@
package captiveportal
import (
"net/http"
"strings"
"github.com/pkg/errors"
)
type OS int
const (
OSUnknown OS = iota
OSAndroid
OSApple
OSWindows
OSLinux
)
type Matcher interface {
Match(r *http.Request) (bool, OS, error)
}
type MatchFunc func(r *http.Request) (bool, OS, error)
func (f MatchFunc) Match(r *http.Request) (bool, OS, error) {
return f(r)
}
func MatchAndroid(r *http.Request) (bool, error) {
// Samples
//
// https://www.google.com/generate_204
// http://connectivitycheck.gstatic.com/generate_204
// http://www.google.com/gen_204
// http://play.googleapis.com/generate_204
// http://connectivitycheck.gstatic.com/generate_204
// http://clients3.google.com/generate_204
// http://g.cn
matches := strings.Contains(r.URL.Path, "generate_204") ||
strings.Contains(r.URL.Path, "gen_204") ||
r.URL.Hostname() == "g.cn"
return matches, nil
}
func MatchApple(r *http.Request) (bool, error) {
// Samples
//
// http://www.apple.com/library/test/success.html
// http://captive.apple.com/hotspot-detect.html
hostname := r.URL.Hostname()
matches := hostname == "www.apple.com" ||
hostname == "captive.apple.com"
return matches, nil
}
func MatchLinux(r *http.Request) (bool, error) {
// Samples
//
// http://start.ubuntu.com/connectivity-check.html
// http://nmcheck.gnome.org/check_network_status.txt
hostname := r.URL.Hostname()
matches := hostname == "start.ubuntu.com" ||
hostname == "nmcheck.gnome.org"
return matches, nil
}
func MatchWindows(r *http.Request) (bool, error) {
// Samples
//
// http://www.msftncsi.com
// http://www.msftncsi.com/ncsi.txt
hostname := r.URL.Hostname()
matches := hostname == "www.msftncsi.com"
return matches, nil
}
// nolint: gochecknoglobals
var defaultMatchers = map[OS]func(r *http.Request) (bool, error){
OSAndroid: MatchAndroid,
OSApple: MatchApple,
OSWindows: MatchWindows,
OSLinux: MatchLinux,
}
func DefaultMatch(r *http.Request) (bool, OS, error) {
for os, match := range defaultMatchers {
matches, err := match(r)
if err != nil {
return false, OSUnknown, errors.WithStack(err)
}
if matches {
return matches, os, nil
}
}
return false, OSUnknown, nil
}

52
middleware.go Normal file
View File

@ -0,0 +1,52 @@
package captiveportal
import (
"net/http"
"github.com/pkg/errors"
)
func (s *Service) Middleware() func(next http.Handler) http.Handler {
registry := s.registry
matcher := s.options.Matcher
liar := s.options.Liar
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
matches, os, err := matcher.Match(r)
if err != nil {
panic(errors.WithStack(err))
}
if !matches {
next.ServeHTTP(w, r)
return
}
id, err := s.options.Identifier.Identify(r)
if err != nil {
panic(errors.Wrap(err, ErrClientIdentificationFailed.Error()))
}
registry.Touch(id, os)
if registry.IsLying(id) {
liar.Handle(os, w, r)
return
}
if registry.IsCaptive(id) {
// Redirect to configured URL
http.Redirect(w, r, s.captivePortalURL, http.StatusTemporaryRedirect)
return
}
next.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}
}

17
option.go Normal file
View File

@ -0,0 +1,17 @@
package captiveportal
type Options struct {
Identifier Identifier
Matcher Matcher
Liar Liar
}
type OptionsFunc func(*Options)
func DefaultOptions() *Options {
return &Options{
Identifier: IdentifierFunc(DefaultIdentifier),
Matcher: MatchFunc(DefaultMatch),
Liar: LiarFunc(DefaultLiar),
}
}

100
registry.go Normal file
View File

@ -0,0 +1,100 @@
package captiveportal
import (
"sync"
"time"
)
type Client struct {
LastSeen time.Time
Captive bool
Lying bool
OS OS
}
type Registry struct {
mutex sync.RWMutex
clients map[string]*Client
}
func (r *Registry) Touch(id string, os OS) {
r.mutex.Lock()
defer r.mutex.Unlock()
client := r.upsert(id)
client.LastSeen = time.Now()
client.OS = os
}
func (r *Registry) Lie(id string) {
r.mutex.Lock()
defer r.mutex.Unlock()
client := r.upsert(id)
client.Lying = true
}
func (r *Registry) Release(id string) {
r.mutex.Lock()
defer r.mutex.Unlock()
client := r.upsert(id)
client.Captive = false
}
func (r *Registry) IsCaptive(id string) bool {
r.mutex.RLock()
defer r.mutex.RUnlock()
client, exists := r.clients[id]
if !exists {
return false
}
return client.Captive
}
func (r *Registry) ClientOS(id string) OS {
r.mutex.RLock()
defer r.mutex.RUnlock()
client, exists := r.clients[id]
if !exists {
return OSUnknown
}
return client.OS
}
func (r *Registry) IsLying(id string) bool {
r.mutex.RLock()
defer r.mutex.RUnlock()
client, exists := r.clients[id]
if !exists {
return false
}
return client.Captive
}
func (r *Registry) upsert(id string) *Client {
client, exists := r.clients[id]
if !exists {
client = &Client{
Captive: true,
Lying: false,
OS: OSUnknown,
}
r.clients[id] = client
}
return client
}
func NewRegistry() *Registry {
return &Registry{
clients: make(map[string]*Client),
}
}

76
service.go Normal file
View File

@ -0,0 +1,76 @@
package captiveportal
import (
"net/http"
"github.com/pkg/errors"
)
type Service struct {
captivePortalURL string
options *Options
registry *Registry
}
func (s *Service) ClientID(r *http.Request) (string, error) {
id, err := s.options.Identifier.Identify(r)
if err != nil {
return "", errors.Wrap(err, ErrClientIdentificationFailed.Error())
}
return id, nil
}
func (s *Service) IsCaptive(r *http.Request) (bool, error) {
id, err := s.ClientID(r)
if err != nil {
return false, errors.WithStack(err)
}
return s.registry.IsCaptive(id), nil
}
func (s *Service) Release(r *http.Request) error {
id, err := s.ClientID(r)
if err != nil {
return errors.WithStack(err)
}
s.registry.Release(id)
return nil
}
func (s *Service) Lie(r *http.Request) error {
id, err := s.ClientID(r)
if err != nil {
return errors.WithStack(err)
}
s.registry.Lie(id)
return nil
}
func (s *Service) ClientOS(r *http.Request) (OS, error) {
id, err := s.ClientID(r)
if err != nil {
return OSUnknown, errors.WithStack(err)
}
return s.registry.ClientOS(id), nil
}
func New(captivePortalURL string, opts ...OptionsFunc) *Service {
options := DefaultOptions()
for _, o := range opts {
o(options)
}
return &Service{
captivePortalURL: captivePortalURL,
options: options,
registry: NewRegistry(),
}
}